AST Optimizations
Before we generate LLVM code, we can make the AST better. By “better”, we mean simpler expressions that compute the same result. For example, `1 + 2 * 3` can become just `7`. This is called optimization.
In this chapter, we introduce the visitor pattern, a classic technique for walking through and transforming ASTs.
The Visitor Pattern
Imagine we want to add a new operation to our AST, like “pretty print” or “count variables” or “simplify expressions”. Without a good design, we would need to modify every AST node type:
// Bad: scattered across many files
impl Expr {
    fn pretty_print(&self) { ... }
    fn count_variables(&self) { ... }
    fn simplify(&self) { ... }
}
Every time we add a new operation, we touch every node. That gets messy.
The visitor pattern flips this around. Instead of adding methods to nodes, we create separate visitor objects:
struct PrettyPrinter;
impl ExprVisitor for PrettyPrinter { ... }
struct ConstantFolder;
impl ExprVisitor for ConstantFolder { ... }
Each visitor is self-contained. Adding a new operation means adding a new visitor, not modifying existing code. This follows the open/closed principle: open for extension, closed for modification.
The ExprVisitor Trait
Our visitor trait provides a hook for each expression type:
/// Visitor trait for traversing typed expressions
///
/// Each visit method returns the transformed expression.
/// Default implementations traverse children recursively.
pub trait ExprVisitor {
    /// Visit any expression - dispatches to specific visit methods
    fn visit_expr(&mut self, expr: &TypedExpr) -> TypedExpr {
        let new_expr = match &expr.expr {
            Expr::Int(n) => self.visit_int(*n),
            Expr::Bool(b) => self.visit_bool(*b),
            Expr::Var(name) => self.visit_var(name),
            Expr::Unary { op, expr: inner } => self.visit_unary(*op, inner),
            Expr::Binary { op, left, right } => self.visit_binary(*op, left, right),
            Expr::Call { name, args } => self.visit_call(name, args),
            Expr::If {
                cond,
                then_branch,
                else_branch,
            } => self.visit_if(cond, then_branch, else_branch),
            Expr::While { cond, body } => self.visit_while(cond, body),
            Expr::Block(stmts) => self.visit_block(stmts),
        };
        TypedExpr {
            expr: new_expr,
            ty: expr.ty.clone(),
        }
    }

    fn visit_int(&mut self, n: i64) -> Expr {
        Expr::Int(n)
    }

    fn visit_bool(&mut self, b: bool) -> Expr {
        Expr::Bool(b)
    }

    fn visit_var(&mut self, name: &str) -> Expr {
        Expr::Var(name.to_string())
    }

    fn visit_unary(&mut self, op: UnaryOp, expr: &TypedExpr) -> Expr {
        let visited = self.visit_expr(expr);
        Expr::Unary {
            op,
            expr: Box::new(visited),
        }
    }

    fn visit_binary(&mut self, op: BinaryOp, left: &TypedExpr, right: &TypedExpr) -> Expr {
        let l = self.visit_expr(left);
        let r = self.visit_expr(right);
        Expr::Binary {
            op,
            left: Box::new(l),
            right: Box::new(r),
        }
    }

    fn visit_call(&mut self, name: &str, args: &[TypedExpr]) -> Expr {
        let visited_args: Vec<TypedExpr> = args.iter().map(|a| self.visit_expr(a)).collect();
        Expr::Call {
            name: name.to_string(),
            args: visited_args,
        }
    }

    fn visit_if(&mut self, cond: &TypedExpr, then_branch: &[Stmt], else_branch: &[Stmt]) -> Expr {
        let visited_cond = self.visit_expr(cond);
        let visited_then: Vec<Stmt> = then_branch.iter().map(|s| self.visit_stmt(s)).collect();
        let visited_else: Vec<Stmt> = else_branch.iter().map(|s| self.visit_stmt(s)).collect();
        Expr::If {
            cond: Box::new(visited_cond),
            then_branch: visited_then,
            else_branch: visited_else,
        }
    }

    fn visit_while(&mut self, cond: &TypedExpr, body: &[Stmt]) -> Expr {
        let visited_cond = self.visit_expr(cond);
        let visited_body: Vec<Stmt> = body.iter().map(|s| self.visit_stmt(s)).collect();
        Expr::While {
            cond: Box::new(visited_cond),
            body: visited_body,
        }
    }

    fn visit_block(&mut self, stmts: &[Stmt]) -> Expr {
        let visited: Vec<Stmt> = stmts.iter().map(|s| self.visit_stmt(s)).collect();
        Expr::Block(visited)
    }

    /// Visit a statement
    fn visit_stmt(&mut self, stmt: &Stmt) -> Stmt {
        match stmt {
            Stmt::Function {
                name,
                params,
                return_type,
                body,
            } => {
                let visited_body: Vec<Stmt> = body.iter().map(|s| self.visit_stmt(s)).collect();
                Stmt::Function {
                    name: name.clone(),
                    params: params.clone(),
                    return_type: return_type.clone(),
                    body: visited_body,
                }
            }
            Stmt::Return(expr) => Stmt::Return(self.visit_expr(expr)),
            Stmt::Assignment {
                name,
                type_ann,
                value,
            } => Stmt::Assignment {
                name: name.clone(),
                type_ann: type_ann.clone(),
                value: self.visit_expr(value),
            },
            Stmt::Expr(expr) => Stmt::Expr(self.visit_expr(expr)),
        }
    }
}
Let us understand how this works:

- `visit_expr` is the entry point. It looks at the expression type and calls the appropriate visit method.
- Each `visit_*` method has a default implementation that just recurses into children. For example, `visit_binary` visits left and right, then rebuilds the binary expression.
- To customize behavior, we override only the methods we care about. For constant folding, we override `visit_binary` to check whether both operands are constants.

This is sometimes called a tree walk or tree traversal.
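For example, suppose we want to count variable references. We override only `visit_var` and let every other node use the default traversal (a minimal sketch; `VariableCounter` is an illustrative name, not part of the compiler):

/// Illustrative visitor: counts variable references during a walk.
struct VariableCounter {
    count: usize,
}

impl ExprVisitor for VariableCounter {
    fn visit_var(&mut self, name: &str) -> Expr {
        // Record the reference, then rebuild the node unchanged
        self.count += 1;
        Expr::Var(name.to_string())
    }
}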
Optimization 1: Constant Folding
Constant folding evaluates expressions where all values are known at compile time:
$$ \begin{aligned} 1 + 2 \times 3 &\Rightarrow 7 \\ 5 < 10 &\Rightarrow \text{true} \\ -(-42) &\Rightarrow 42 \end{aligned} $$
Why wait until runtime to compute `1 + 2` when we can do it now?
Here is the pseudocode:
FUNCTION constant_fold(expr):
    case expr of:
    | Binary(op, left, right):
        folded_left  = constant_fold(left)
        folded_right = constant_fold(right)
        if both folded_left and folded_right are constants:
            compute the result at compile time
            return the constant result
        else:
            return Binary(op, folded_left, folded_right)
    | Unary(op, inner):
        folded = constant_fold(inner)
        if folded is a constant:
            compute result
            return constant
        else:
            return Unary(op, folded)
    | other:
        return expr  # cannot fold
And the implementation:
/// Folds constant expressions: `1 + 2` becomes `3`
///
/// This is a simple optimization that evaluates expressions
/// where all operands are known at compile time.
pub struct ConstantFolder;

impl ConstantFolder {
    pub fn new() -> Self {
        ConstantFolder
    }

    pub fn fold_program(stmts: &[Stmt]) -> Vec<Stmt> {
        let mut folder = ConstantFolder::new();
        stmts.iter().map(|s| folder.visit_stmt(s)).collect()
    }
}

impl Default for ConstantFolder {
    fn default() -> Self {
        Self::new()
    }
}

impl ExprVisitor for ConstantFolder {
    fn visit_binary(&mut self, op: BinaryOp, left: &TypedExpr, right: &TypedExpr) -> Expr {
        // First, recursively fold children
        let l = self.visit_expr(left);
        let r = self.visit_expr(right);

        // Try to fold if both are constants. Checked arithmetic returns
        // `None` on overflow, division by zero, or `i64::MIN / -1`, so we
        // leave such expressions unfolded instead of panicking at compile time.
        if let (Expr::Int(lv), Expr::Int(rv)) = (&l.expr, &r.expr) {
            let result = match op {
                BinaryOp::Add => lv.checked_add(*rv),
                BinaryOp::Sub => lv.checked_sub(*rv),
                BinaryOp::Mul => lv.checked_mul(*rv),
                BinaryOp::Div => lv.checked_div(*rv),
                BinaryOp::Mod => lv.checked_rem(*rv),
                _ => None,
            };
            if let Some(val) = result {
                return Expr::Int(val);
            }
        }

        // Try boolean constant folding for comparisons
        if let (Expr::Int(lv), Expr::Int(rv)) = (&l.expr, &r.expr) {
            let result = match op {
                BinaryOp::Lt => Some(*lv < *rv),
                BinaryOp::Gt => Some(*lv > *rv),
                BinaryOp::Le => Some(*lv <= *rv),
                BinaryOp::Ge => Some(*lv >= *rv),
                BinaryOp::Eq => Some(*lv == *rv),
                BinaryOp::Ne => Some(*lv != *rv),
                _ => None,
            };
            if let Some(val) = result {
                return Expr::Bool(val);
            }
        }

        // Can't fold, return as-is
        Expr::Binary {
            op,
            left: Box::new(l),
            right: Box::new(r),
        }
    }

    fn visit_unary(&mut self, op: UnaryOp, expr: &TypedExpr) -> Expr {
        let e = self.visit_expr(expr);
        match (&op, &e.expr) {
            // `checked_neg` guards against `-i64::MIN`, which would overflow
            (UnaryOp::Neg, Expr::Int(n)) => match n.checked_neg() {
                Some(v) => Expr::Int(v),
                None => Expr::Unary {
                    op,
                    expr: Box::new(e),
                },
            },
            (UnaryOp::Not, Expr::Bool(b)) => Expr::Bool(!b),
            _ => Expr::Unary {
                op,
                expr: Box::new(e),
            },
        }
    }
}
Let us trace through `1 + 2 * 3`:

- Visit the outer `+` expression
- Recursively visit the left operand (`1`) → returns `Int(1)`
- Recursively visit the right operand (`2 * 3`) → visits `*`, finds `Int(2)` and `Int(3)`, returns `Int(6)`
- Back at `+`: left is `Int(1)`, right is `Int(6)` → return `Int(7)`

The whole expression becomes just `7`.
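To see the folder in action, here is a sketch of a unit test. It assumes a `Type::Int` variant for the `ty` field and that `Expr` derives `PartialEq` and `Debug`; adjust to match your actual AST definitions:

// Sketch only: `int` is a hypothetical helper, and `Type::Int` is an
// assumed variant of the type enum stored in `TypedExpr`.
fn int(n: i64) -> TypedExpr {
    TypedExpr {
        expr: Expr::Int(n),
        ty: Type::Int,
    }
}

#[test]
fn folds_one_plus_two_times_three() {
    // Build the AST for `1 + (2 * 3)`
    let expr = TypedExpr {
        expr: Expr::Binary {
            op: BinaryOp::Add,
            left: Box::new(int(1)),
            right: Box::new(TypedExpr {
                expr: Expr::Binary {
                    op: BinaryOp::Mul,
                    left: Box::new(int(2)),
                    right: Box::new(int(3)),
                },
                ty: Type::Int,
            }),
        },
        ty: Type::Int,
    };

    let folded = ConstantFolder::new().visit_expr(&expr);
    assert_eq!(folded.expr, Expr::Int(7));
}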
Optimization 2: Algebraic Simplification
Algebraic simplification applies mathematical identities. (It is related to, but distinct from, strength reduction, which replaces expensive operations with cheaper ones, such as `x * 2` with `x << 1`.)
| Expression | Simplified | Identity Applied |
|---|---|---|
| `x + 0` | `x` | Additive identity |
| `x - 0` | `x` | Additive identity |
| `x * 0` | `0` | Zero property |
| `x * 1` | `x` | Multiplicative identity |
| `x / 1` | `x` | Multiplicative identity |
| `0 + x` | `x` | Commutativity + identity |
| `1 * x` | `x` | Commutativity + identity |
These transformations save runtime computation. One caveat: rewriting `x * 0` to `0` discards `x` entirely, which is only safe when evaluating `x` has no side effects; if `x` were a call to a function that prints something, the call would silently disappear.
Pseudocode:
FUNCTION simplify(expr):
    case expr of:
    | Binary(Add, x, Int(0)): return simplify(x)
    | Binary(Add, Int(0), x): return simplify(x)
    | Binary(Mul, x, Int(1)): return simplify(x)
    | Binary(Mul, Int(1), x): return simplify(x)
    | Binary(Mul, _, Int(0)): return Int(0)
    | Binary(Mul, Int(0), _): return Int(0)
    | ...  # other cases
    | Binary(op, left, right):
        return Binary(op, simplify(left), simplify(right))
    | other:
        return expr
Implementation:
/// Applies algebraic simplifications:
/// - `x + 0` → `x`
/// - `x - 0` → `x`
/// - `x * 0` → `0`
/// - `x * 1` → `x`
/// - `x / 1` → `x`
/// - `0 + x` → `x`
/// - `1 * x` → `x`
/// - `0 * x` → `0`
pub struct AlgebraicSimplifier;

impl AlgebraicSimplifier {
    pub fn new() -> Self {
        AlgebraicSimplifier
    }

    pub fn simplify_program(stmts: &[Stmt]) -> Vec<Stmt> {
        let mut simplifier = AlgebraicSimplifier::new();
        stmts.iter().map(|s| simplifier.visit_stmt(s)).collect()
    }
}

impl Default for AlgebraicSimplifier {
    fn default() -> Self {
        Self::new()
    }
}

impl ExprVisitor for AlgebraicSimplifier {
    fn visit_binary(&mut self, op: BinaryOp, left: &TypedExpr, right: &TypedExpr) -> Expr {
        // First, recursively simplify children
        let l = self.visit_expr(left);
        let r = self.visit_expr(right);

        // Apply algebraic identities
        match (&op, &l.expr, &r.expr) {
            // x + 0 = x
            (BinaryOp::Add, _, Expr::Int(0)) => return l.expr,
            // 0 + x = x
            (BinaryOp::Add, Expr::Int(0), _) => return r.expr,
            // x - 0 = x
            (BinaryOp::Sub, _, Expr::Int(0)) => return l.expr,
            // x * 0 = 0
            (BinaryOp::Mul, _, Expr::Int(0)) => return Expr::Int(0),
            // 0 * x = 0
            (BinaryOp::Mul, Expr::Int(0), _) => return Expr::Int(0),
            // x * 1 = x
            (BinaryOp::Mul, _, Expr::Int(1)) => return l.expr,
            // 1 * x = x
            (BinaryOp::Mul, Expr::Int(1), _) => return r.expr,
            // x / 1 = x
            (BinaryOp::Div, _, Expr::Int(1)) => return l.expr,
            _ => {}
        }

        Expr::Binary {
            op,
            left: Box::new(l),
            right: Box::new(r),
        }
    }
}
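A quick sanity check, reusing the hypothetical `int` helper and `Type::Int` assumption from the constant-folding test sketch above:

#[test]
fn simplifies_x_times_one() {
    // Build the AST for `x * 1`
    let expr = TypedExpr {
        expr: Expr::Binary {
            op: BinaryOp::Mul,
            left: Box::new(TypedExpr {
                expr: Expr::Var("x".to_string()),
                ty: Type::Int,
            }),
            right: Box::new(int(1)),
        },
        ty: Type::Int,
    };

    let simplified = AlgebraicSimplifier::new().visit_expr(&expr);
    assert_eq!(simplified.expr, Expr::Var("x".to_string()));
}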
Chaining Optimizations
Multiple optimization passes can be chained. This is called a pass pipeline:
pub fn optimize_program(stmts: &[Stmt]) -> Vec<Stmt> {
    // First: fold constants
    let folded = ConstantFolder::fold_program(stmts);
    // Then: simplify algebra
    AlgebraicSimplifier::simplify_program(&folded)
}
Consider `x * (1 + 0)`:

- After constant folding: `x * 1` (because `1 + 0 = 1`)
- After algebraic simplification: `x` (because `x * 1 = x`)

Two passes, significant simplification. The order matters: running constant folding first creates opportunities for algebraic simplification.
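Because each pass can expose new opportunities for the other, a common refinement is to repeat the pipeline until the program stops changing. A minimal sketch, assuming `Stmt` derives `PartialEq` (`optimize_to_fixpoint` is a hypothetical helper, not part of the compiler yet):

pub fn optimize_to_fixpoint(mut stmts: Vec<Stmt>) -> Vec<Stmt> {
    loop {
        let folded = ConstantFolder::fold_program(&stmts);
        let next = AlgebraicSimplifier::simplify_program(&folded);
        // Stop once a full round of passes changes nothing
        if next == stmts {
            return next;
        }
        stmts = next;
    }
}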
Why Bother?
You might wonder: “LLVM will optimize this anyway. Why do it ourselves?”
Good question. LLVM will do these optimizations. But:
- Learning: Implementing optimizations helps you understand how compilers work. These are the same techniques used in production compilers.
- Simplicity: A simpler AST means simpler code generation. Less can go wrong.
- Debug output: When you print the AST for debugging, optimized code is easier to read.
- Specialized optimizations: You might know things about your language that LLVM does not. Custom optimizations can exploit that knowledge.
- Compile time: A simpler AST means less work for LLVM, which means faster compilation.
Other Common Optimizations
Production compilers do many more optimizations:
| Optimization | What it does |
|---|---|
| Dead code elimination | Remove unreachable code |
| Common subexpression elimination | Compute `x * y` once if it is used twice |
| Loop unrolling | Replace loops with repeated code |
| Inlining | Replace function calls with function bodies |
| Tail call optimization | Turn tail recursion into loops |
We leave these as exercises. The visitor pattern makes adding new optimizations straightforward.
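As a taste of the first exercise, here is a sketch of a visitor that prunes branches whose condition is a known constant. `DeadBranchEliminator` is our own illustrative name, and the sketch assumes that replacing an `If` with a `Block` of the surviving branch is type-correct in our language. Run it after `ConstantFolder`, so that conditions like `1 < 2` have already become `Bool(true)`:

struct DeadBranchEliminator;

impl ExprVisitor for DeadBranchEliminator {
    fn visit_if(&mut self, cond: &TypedExpr, then_branch: &[Stmt], else_branch: &[Stmt]) -> Expr {
        let cond = self.visit_expr(cond);
        let then_branch: Vec<Stmt> = then_branch.iter().map(|s| self.visit_stmt(s)).collect();
        let else_branch: Vec<Stmt> = else_branch.iter().map(|s| self.visit_stmt(s)).collect();
        match &cond.expr {
            // `if true { A } else { B }` → `A`
            Expr::Bool(true) => Expr::Block(then_branch),
            // `if false { A } else { B }` → `B`
            Expr::Bool(false) => Expr::Block(else_branch),
            // Condition unknown at compile time: keep the `if`
            _ => Expr::If {
                cond: Box::new(cond),
                then_branch,
                else_branch,
            },
        }
    }
}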
Using the Optimizations
Enable optimizations with the `-O` flag:
# Without optimization
rustup run nightly cargo run -- --ir examples/fibonacci.sl
# With optimization
rustup run nightly cargo run -- --ir -O examples/fibonacci.sl
Testing
Run the test suite to make sure the optimizations do not change the meaning of any program:
rustup run nightly cargo test
In the next chapter, we examine what LLVM IR looks like before we start generating it.