
AST Optimizations

Before we generate LLVM code, we can make the AST better. By “better”, we mean simpler expressions that do the same thing. For example, 1 + 2 * 3 can become just 7. This is called optimization.

In this chapter, we introduce the visitor pattern, a classic technique for walking through and transforming ASTs.

The Visitor Pattern

Imagine we want to add a new operation on our AST, like “pretty print” or “count variables” or “simplify expressions”. Without a good design, we would need to modify every AST node type:

// Bad: scattered across many files
impl Expr {
    fn pretty_print(&self) { ... }
    fn count_variables(&self) { ... }
    fn simplify(&self) { ... }
}

Every time we add a new operation, we touch every node. That gets messy.

The visitor pattern flips this around. Instead of adding methods to nodes, we create separate visitor objects:

struct PrettyPrinter;
impl ExprVisitor for PrettyPrinter { ... }

struct ConstantFolder;
impl ExprVisitor for ConstantFolder { ... }

Each visitor is self-contained. Adding a new operation means adding a new visitor, not modifying existing code. This follows the open/closed principle: open for extension, closed for modification.
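Here is a minimal sketch of the idea on a toy expression type. The `Expr`, `Visitor`, and `VarCounter` names are illustrative and are not the book's code. The walk over children lives in the trait's default `visit` method. The new "count variables" operation is therefore one small struct that overrides a single hook, and `Expr` itself is untouched:

```rust
/// A toy expression type standing in for the real AST.
#[derive(Debug, Clone, PartialEq)]
enum Expr {
    Int(i64),
    Var(String),
    Add(Box<Expr>, Box<Expr>),
    Mul(Box<Expr>, Box<Expr>),
}

/// Read-only visitor: the default `visit` walks every child,
/// and the per-node hooks do nothing unless overridden.
trait Visitor {
    fn visit(&mut self, e: &Expr) {
        match e {
            Expr::Int(n) => self.visit_int(*n),
            Expr::Var(name) => self.visit_var(name),
            Expr::Add(l, r) | Expr::Mul(l, r) => {
                self.visit(l);
                self.visit(r);
            }
        }
    }
    fn visit_int(&mut self, _n: i64) {}
    fn visit_var(&mut self, _name: &str) {}
}

/// New operation: count variable references. `Expr` is not modified.
struct VarCounter {
    count: usize,
}

impl Visitor for VarCounter {
    fn visit_var(&mut self, _name: &str) {
        self.count += 1;
    }
}
```

For `x * (y + 1)`, the counter finishes with `count == 2`.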

The ExprVisitor Trait

Our visitor trait provides a hook for each expression type:

/// Visitor trait for traversing typed expressions
///
/// Each visit method returns the transformed expression.
/// Default implementations traverse children recursively.
pub trait ExprVisitor {
    /// Visit any expression - dispatches to specific visit methods
    fn visit_expr(&mut self, expr: &TypedExpr) -> TypedExpr {
        let new_expr = match &expr.expr {
            Expr::Int(n) => self.visit_int(*n),
            Expr::Bool(b) => self.visit_bool(*b),
            Expr::Var(name) => self.visit_var(name),
            Expr::Unary { op, expr: inner } => self.visit_unary(*op, inner),
            Expr::Binary { op, left, right } => self.visit_binary(*op, left, right),
            Expr::Call { name, args } => self.visit_call(name, args),
            Expr::If {
                cond,
                then_branch,
                else_branch,
            } => self.visit_if(cond, then_branch, else_branch),
            Expr::While { cond, body } => self.visit_while(cond, body),
            Expr::Block(stmts) => self.visit_block(stmts),
        };
        TypedExpr {
            expr: new_expr,
            ty: expr.ty.clone(),
        }
    }

    fn visit_int(&mut self, n: i64) -> Expr {
        Expr::Int(n)
    }

    fn visit_bool(&mut self, b: bool) -> Expr {
        Expr::Bool(b)
    }

    fn visit_var(&mut self, name: &str) -> Expr {
        Expr::Var(name.to_string())
    }

    fn visit_unary(&mut self, op: UnaryOp, expr: &TypedExpr) -> Expr {
        let visited = self.visit_expr(expr);
        Expr::Unary {
            op,
            expr: Box::new(visited),
        }
    }

    fn visit_binary(&mut self, op: BinaryOp, left: &TypedExpr, right: &TypedExpr) -> Expr {
        let l = self.visit_expr(left);
        let r = self.visit_expr(right);
        Expr::Binary {
            op,
            left: Box::new(l),
            right: Box::new(r),
        }
    }

    fn visit_call(&mut self, name: &str, args: &[TypedExpr]) -> Expr {
        let visited_args: Vec<TypedExpr> = args.iter().map(|a| self.visit_expr(a)).collect();
        Expr::Call {
            name: name.to_string(),
            args: visited_args,
        }
    }

    fn visit_if(&mut self, cond: &TypedExpr, then_branch: &[Stmt], else_branch: &[Stmt]) -> Expr {
        let visited_cond = self.visit_expr(cond);
        let visited_then: Vec<Stmt> = then_branch.iter().map(|s| self.visit_stmt(s)).collect();
        let visited_else: Vec<Stmt> = else_branch.iter().map(|s| self.visit_stmt(s)).collect();
        Expr::If {
            cond: Box::new(visited_cond),
            then_branch: visited_then,
            else_branch: visited_else,
        }
    }

    fn visit_while(&mut self, cond: &TypedExpr, body: &[Stmt]) -> Expr {
        let visited_cond = self.visit_expr(cond);
        let visited_body: Vec<Stmt> = body.iter().map(|s| self.visit_stmt(s)).collect();
        Expr::While {
            cond: Box::new(visited_cond),
            body: visited_body,
        }
    }

    fn visit_block(&mut self, stmts: &[Stmt]) -> Expr {
        let visited: Vec<Stmt> = stmts.iter().map(|s| self.visit_stmt(s)).collect();
        Expr::Block(visited)
    }

    /// Visit a statement
    fn visit_stmt(&mut self, stmt: &Stmt) -> Stmt {
        match stmt {
            Stmt::Function {
                name,
                params,
                return_type,
                body,
            } => {
                let visited_body: Vec<Stmt> = body.iter().map(|s| self.visit_stmt(s)).collect();
                Stmt::Function {
                    name: name.clone(),
                    params: params.clone(),
                    return_type: return_type.clone(),
                    body: visited_body,
                }
            }
            Stmt::Return(expr) => Stmt::Return(self.visit_expr(expr)),
            Stmt::Assignment {
                name,
                type_ann,
                value,
            } => Stmt::Assignment {
                name: name.clone(),
                type_ann: type_ann.clone(),
                value: self.visit_expr(value),
            },
            Stmt::Expr(expr) => Stmt::Expr(self.visit_expr(expr)),
        }
    }
}

secondlang/src/visitor.rs

Let us understand how this works:

  1. visit_expr is the entry point. It looks at the expression type and calls the appropriate visitor method.

  2. Each visit_* method has a default implementation that just recurses into children. For example, visit_binary visits left and right, then rebuilds the binary expression.

  3. To customize behavior, we override the methods we care about. For constant folding, we override visit_binary to check if both operands are constants.

This is sometimes called a tree walk or tree traversal.
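The "defaults rebuild, overrides customize" structure can be sketched on a toy AST. The `Expr`, `ExprVisitor`, and `Renamer` names are hypothetical, and a `char` stands in for `BinaryOp`. `Renamer` overrides only `visit_var`. The default `visit_binary` does all the recursion and rebuilding:

```rust
use std::collections::HashMap;

#[derive(Debug, Clone, PartialEq)]
enum Expr {
    Int(i64),
    Var(String),
    Binary(char, Box<Expr>, Box<Expr>),
}

/// Transforming visitor: every hook returns a new `Expr`.
/// The defaults visit children and rebuild the node unchanged.
trait ExprVisitor {
    fn visit_expr(&mut self, e: &Expr) -> Expr {
        match e {
            Expr::Int(n) => self.visit_int(*n),
            Expr::Var(name) => self.visit_var(name),
            Expr::Binary(op, l, r) => self.visit_binary(*op, l, r),
        }
    }
    fn visit_int(&mut self, n: i64) -> Expr {
        Expr::Int(n)
    }
    fn visit_var(&mut self, name: &str) -> Expr {
        Expr::Var(name.to_string())
    }
    fn visit_binary(&mut self, op: char, l: &Expr, r: &Expr) -> Expr {
        let l = self.visit_expr(l);
        let r = self.visit_expr(r);
        Expr::Binary(op, Box::new(l), Box::new(r))
    }
}

/// Overrides a single hook; recursion comes from the defaults.
struct Renamer {
    map: HashMap<String, String>,
}

impl ExprVisitor for Renamer {
    fn visit_var(&mut self, name: &str) -> Expr {
        // Names missing from the map are kept as they are.
        let new = self.map.get(name).cloned().unwrap_or_else(|| name.to_string());
        Expr::Var(new)
    }
}
```

With the map `{a: x}`, `a + b * 2` becomes `x + b * 2`. Since `b` is not in the map, it is left alone.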

Optimization 1: Constant Folding

Constant folding evaluates expressions where all values are known at compile time:

$$ \begin{aligned} 1 + 2 \times 3 &\Rightarrow 7 \\ 5 < 10 &\Rightarrow \text{true} \\ -(-42) &\Rightarrow 42 \end{aligned} $$

Why wait until runtime to compute 1 + 2 when we can do it now?

Here is the pseudocode:

FUNCTION constant_fold(expr):
   case expr of:
   | Binary(op, left, right):
       folded_left = constant_fold(left)
       folded_right = constant_fold(right)

       if both folded_left and folded_right are constants:
           compute the result at compile time
           return the constant result
       else:
           return Binary(op, folded_left, folded_right)

   | Unary(op, inner):
       folded = constant_fold(inner)
       if folded is a constant:
           compute result
           return constant
       else:
           return Unary(op, folded)

   | other:
       fold any children (default visitor behavior) and return the result
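The pseudocode maps almost directly onto a plain recursive function. This sketch uses a toy AST: `BinOp` and `Expr` are illustrative stand-ins for `BinaryOp` and `TypedExpr`. It also adds a safeguard that the pseudocode does not show. Rust's `checked_*` arithmetic returns `None` on overflow or division by zero, and in that case the expression is left unfolded so the compiler does not panic:

```rust
#[derive(Debug, Clone, Copy, PartialEq)]
enum BinOp { Add, Sub, Mul, Div, Lt }

#[derive(Debug, Clone, PartialEq)]
enum Expr {
    Int(i64),
    Bool(bool),
    Var(String),
    Neg(Box<Expr>),
    Not(Box<Expr>),
    Binary(BinOp, Box<Expr>, Box<Expr>),
}

/// Direct transcription of the constant_fold pseudocode.
fn constant_fold(expr: &Expr) -> Expr {
    match expr {
        Expr::Binary(op, left, right) => {
            // Fold children first so nested constants bubble up.
            let l = constant_fold(left);
            let r = constant_fold(right);
            if let (Expr::Int(a), Expr::Int(b)) = (&l, &r) {
                // None means overflow or division by zero: leave it for runtime.
                let folded = match op {
                    BinOp::Add => a.checked_add(*b).map(Expr::Int),
                    BinOp::Sub => a.checked_sub(*b).map(Expr::Int),
                    BinOp::Mul => a.checked_mul(*b).map(Expr::Int),
                    BinOp::Div => a.checked_div(*b).map(Expr::Int),
                    BinOp::Lt => Some(Expr::Bool(a < b)),
                };
                if let Some(e) = folded {
                    return e;
                }
            }
            Expr::Binary(*op, Box::new(l), Box::new(r))
        }
        Expr::Neg(inner) => match constant_fold(inner) {
            // i64::MIN has no positive counterpart, so don't fold it.
            Expr::Int(n) if n != i64::MIN => Expr::Int(-n),
            other => Expr::Neg(Box::new(other)),
        },
        Expr::Not(inner) => match constant_fold(inner) {
            Expr::Bool(b) => Expr::Bool(!b),
            other => Expr::Not(Box::new(other)),
        },
        // Int, Bool, Var: nothing to fold.
        other => other.clone(),
    }
}
```

The examples from the equations above behave as expected. `1 + 2 * 3` folds to `Int(7)`, `5 < 10` folds to `Bool(true)`, and `-(-42)` folds to `Int(42)`. `1 / 0` is returned unchanged.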

And the implementation:

/// Folds constant expressions: `1 + 2` becomes `3`
///
/// This is a simple optimization that evaluates expressions
/// where all operands are known at compile time.
pub struct ConstantFolder;

impl ConstantFolder {
    pub fn new() -> Self {
        ConstantFolder
    }

    pub fn fold_program(stmts: &[Stmt]) -> Vec<Stmt> {
        let mut folder = ConstantFolder::new();
        stmts.iter().map(|s| folder.visit_stmt(s)).collect()
    }
}

impl Default for ConstantFolder {
    fn default() -> Self {
        Self::new()
    }
}

impl ExprVisitor for ConstantFolder {
    fn visit_binary(&mut self, op: BinaryOp, left: &TypedExpr, right: &TypedExpr) -> Expr {
        // First, recursively fold children
        let l = self.visit_expr(left);
        let r = self.visit_expr(right);

        // Try to fold if both operands are integer constants.
        // checked_* returns None on division by zero or i64 overflow;
        // in that case we leave the expression for runtime instead of
        // panicking inside the compiler.
        if let (Expr::Int(lv), Expr::Int(rv)) = (&l.expr, &r.expr) {
            let result = match op {
                BinaryOp::Add => lv.checked_add(*rv).map(Expr::Int),
                BinaryOp::Sub => lv.checked_sub(*rv).map(Expr::Int),
                BinaryOp::Mul => lv.checked_mul(*rv).map(Expr::Int),
                BinaryOp::Div => lv.checked_div(*rv).map(Expr::Int),
                BinaryOp::Mod => lv.checked_rem(*rv).map(Expr::Int),
                // Comparisons fold to a boolean constant
                BinaryOp::Lt => Some(Expr::Bool(lv < rv)),
                BinaryOp::Gt => Some(Expr::Bool(lv > rv)),
                BinaryOp::Le => Some(Expr::Bool(lv <= rv)),
                BinaryOp::Ge => Some(Expr::Bool(lv >= rv)),
                BinaryOp::Eq => Some(Expr::Bool(lv == rv)),
                BinaryOp::Ne => Some(Expr::Bool(lv != rv)),
                _ => None,
            };
            if let Some(folded) = result {
                return folded;
            }
        }

        // Can't fold, return as-is
        Expr::Binary {
            op,
            left: Box::new(l),
            right: Box::new(r),
        }
    }

    fn visit_unary(&mut self, op: UnaryOp, expr: &TypedExpr) -> Expr {
        let e = self.visit_expr(expr);

        match (&op, &e.expr) {
            // Skip i64::MIN: its negation overflows
            (UnaryOp::Neg, Expr::Int(n)) if *n != i64::MIN => Expr::Int(-n),
            (UnaryOp::Not, Expr::Bool(b)) => Expr::Bool(!b),
            _ => Expr::Unary {
                op,
                expr: Box::new(e),
            },
        }
    }
}

secondlang/src/visitor.rs

Let us trace through 1 + 2 * 3:

  1. Visit the outer + expression
  2. Recursively visit left (1) → returns Int(1)
  3. Recursively visit right (2 * 3) → visits *, finds Int(2) and Int(3), returns Int(6)
  4. Back at +: left is Int(1), right is Int(6) → return Int(7)

The whole expression becomes just 7.
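The bottom-up order of this trace can be observed with a small sketch. It uses a hypothetical `TracingFolder` over a toy AST with only `+` and `*`, and it logs each fold as it happens:

```rust
#[derive(Debug, Clone, PartialEq)]
enum Expr {
    Int(i64),
    Add(Box<Expr>, Box<Expr>),
    Mul(Box<Expr>, Box<Expr>),
}

/// Folds constant `+` and `*`, recording each fold in `log`.
struct TracingFolder {
    log: Vec<String>,
}

impl TracingFolder {
    fn fold(&mut self, e: &Expr) -> Expr {
        match e {
            Expr::Int(n) => Expr::Int(*n),
            Expr::Add(l, r) => self.fold_bin('+', l, r),
            Expr::Mul(l, r) => self.fold_bin('*', l, r),
        }
    }

    fn fold_bin(&mut self, op: char, l: &Expr, r: &Expr) -> Expr {
        // Children are folded before the parent (a post-order walk).
        let l = self.fold(l);
        let r = self.fold(r);
        if let (Expr::Int(a), Expr::Int(b)) = (&l, &r) {
            let v = if op == '+' { a.checked_add(*b) } else { a.checked_mul(*b) };
            if let Some(v) = v {
                self.log.push(format!("{a} {op} {b} => {v}"));
                return Expr::Int(v);
            }
        }
        // Not foldable (or overflowed): rebuild the node.
        match op {
            '+' => Expr::Add(Box::new(l), Box::new(r)),
            _ => Expr::Mul(Box::new(l), Box::new(r)),
        }
    }
}
```

Folding `1 + 2 * 3` returns `Int(7)` and leaves `["2 * 3 => 6", "1 + 6 => 7"]` in `log`. The inner multiplication is folded first, as in steps 3 and 4.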

Optimization 2: Algebraic Simplification

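Algebraic simplification applies mathematical identities to remove operations that cannot change the result. It is distinct from strength reduction, which replaces an expensive operation with a cheaper equivalent (such as x * 2 with x + x). The identities we use: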

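| Expression | Simplified | Identity applied |
|---|---|---|
| x + 0 | x | Additive identity |
| x - 0 | x | Additive identity |
| x * 0 | 0 | Zero property |
| x * 1 | x | Multiplicative identity |
| x / 1 | x | Multiplicative identity |
| 0 + x | x | Commutativity + identity |
| 1 * x | x | Commutativity + identity |
| 0 * x | 0 | Commutativity + zero property |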

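For integer operands, each of these rewrites gives the same value for every x, and each one saves an operation at runtime. The zero-property rows need one caveat: x * 0 → 0 and 0 * x → 0 discard x without evaluating it, so they are only safe when evaluating x has no side effects. If x is a function call, for example, the simplified program never makes that call.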
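The two zero-property rewrites need the most care, because they discard an operand without evaluating it. A conservative guard is sketched below. The toy `Expr`, `is_pure`, and `simplify_mul` are hypothetical, and the book's `AlgebraicSimplifier` applies these rules unconditionally. The guard assumes every call may have side effects (a call might never return, for example). In the book's language, a `while` expression would deserve the same treatment:

```rust
#[derive(Debug, Clone, PartialEq)]
enum Expr {
    Int(i64),
    Var(String),
    Call(String, Vec<Expr>),
    Mul(Box<Expr>, Box<Expr>),
}

/// Conservative purity check: literals and variables are pure,
/// a product is pure if both factors are, and any call is assumed
/// to have side effects.
fn is_pure(e: &Expr) -> bool {
    match e {
        Expr::Int(_) | Expr::Var(_) => true,
        Expr::Mul(l, r) => is_pure(l) && is_pure(r),
        Expr::Call(_, _) => false,
    }
}

/// Applies `x * 0 -> 0` and `0 * x -> 0` only when the discarded
/// operand can be dropped without changing behavior.
fn simplify_mul(l: Expr, r: Expr) -> Expr {
    let drop_left = matches!(&r, Expr::Int(0)) && is_pure(&l);
    let drop_right = matches!(&l, Expr::Int(0)) && is_pure(&r);
    if drop_left || drop_right {
        Expr::Int(0)
    } else {
        Expr::Mul(Box::new(l), Box::new(r))
    }
}
```

`x * 0` still becomes `0`. `f() * 0` is kept as a multiplication, so the call still runs.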

Pseudocode:

FUNCTION simplify(expr):
   case expr of:
   | Binary(op, left, right):
       l = simplify(left)       # simplify children first
       r = simplify(right)
       case (op, l, r) of:
       | (Add, x, Int(0)) or (Sub, x, Int(0)): return x
       | (Add, Int(0), x): return x
       | (Mul, x, Int(1)) or (Div, x, Int(1)): return x
       | (Mul, Int(1), x): return x
       | (Mul, _, Int(0)) or (Mul, Int(0), _): return Int(0)
       | otherwise: return Binary(op, l, r)
   | other:
       simplify any children (default visitor behavior) and return the result
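The pseudocode translates almost line for line into Rust. This sketch uses a toy AST: `Op` and `Expr` are illustrative stand-ins for `BinaryOp` and `TypedExpr`. Like the book's implementation, it simplifies children before matching:

```rust
#[derive(Debug, Clone, Copy, PartialEq)]
enum Op { Add, Sub, Mul, Div }

#[derive(Debug, Clone, PartialEq)]
enum Expr {
    Int(i64),
    Var(String),
    Bin(Op, Box<Expr>, Box<Expr>),
}

/// Runnable version of the `simplify` pseudocode.
fn simplify(expr: &Expr) -> Expr {
    let (op, left, right) = match expr {
        Expr::Bin(op, left, right) => (*op, left, right),
        // Leaves (literals, variables) have nothing to simplify.
        _ => return expr.clone(),
    };
    // Simplify children first so identities they expose are visible here.
    let l = simplify(left);
    let r = simplify(right);
    match (op, l, r) {
        (Op::Add | Op::Sub, x, Expr::Int(0)) => x,
        (Op::Add, Expr::Int(0), x) => x,
        (Op::Mul | Op::Div, x, Expr::Int(1)) => x,
        (Op::Mul, Expr::Int(1), x) => x,
        // Zero rules discard the other operand; only safe if it has no side effects.
        (Op::Mul, _, Expr::Int(0)) | (Op::Mul, Expr::Int(0), _) => Expr::Int(0),
        (op, l, r) => Expr::Bin(op, Box::new(l), Box::new(r)),
    }
}
```

Because children are simplified first, `x + (0 * y)` reduces to `x`. The inner `0 * y` becomes `0`, which exposes `x + 0` to the parent.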

Implementation:

/// Applies algebraic simplifications:
/// - `x + 0` → `x`
/// - `x - 0` → `x`
/// - `x * 0` → `0`
/// - `x * 1` → `x`
/// - `x / 1` → `x`
/// - `0 + x` → `x`
/// - `1 * x` → `x`
/// - `0 * x` → `0`
pub struct AlgebraicSimplifier;

impl AlgebraicSimplifier {
    pub fn new() -> Self {
        AlgebraicSimplifier
    }

    pub fn simplify_program(stmts: &[Stmt]) -> Vec<Stmt> {
        let mut simplifier = AlgebraicSimplifier::new();
        stmts.iter().map(|s| simplifier.visit_stmt(s)).collect()
    }
}

impl Default for AlgebraicSimplifier {
    fn default() -> Self {
        Self::new()
    }
}

impl ExprVisitor for AlgebraicSimplifier {
    fn visit_binary(&mut self, op: BinaryOp, left: &TypedExpr, right: &TypedExpr) -> Expr {
        // First, recursively simplify children
        let l = self.visit_expr(left);
        let r = self.visit_expr(right);

        // Apply algebraic identities
        match (&op, &l.expr, &r.expr) {
            // x + 0 = x
            (BinaryOp::Add, _, Expr::Int(0)) => return l.expr,
            // 0 + x = x
            (BinaryOp::Add, Expr::Int(0), _) => return r.expr,
            // x - 0 = x
            (BinaryOp::Sub, _, Expr::Int(0)) => return l.expr,
            // x * 0 = 0 (discards x; assumes x has no side effects)
            (BinaryOp::Mul, _, Expr::Int(0)) => return Expr::Int(0),
            // 0 * x = 0 (same assumption)
            (BinaryOp::Mul, Expr::Int(0), _) => return Expr::Int(0),
            // x * 1 = x
            (BinaryOp::Mul, _, Expr::Int(1)) => return l.expr,
            // 1 * x = x
            (BinaryOp::Mul, Expr::Int(1), _) => return r.expr,
            // x / 1 = x
            (BinaryOp::Div, _, Expr::Int(1)) => return l.expr,
            _ => {}
        }

        Expr::Binary {
            op,
            left: Box::new(l),
            right: Box::new(r),
        }
    }
}

secondlang/src/visitor.rs

Chaining Optimizations

Multiple optimization passes can be chained. This is called a pass pipeline:

pub fn optimize_program(program: &Program) -> Program {
    // First: fold constants
    let folded = ConstantFolder::fold_program(program);
    // Then: simplify algebra
    AlgebraicSimplifier::simplify_program(&folded)
}

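Consider x * (2 - 1):

  1. After constant folding: x * 1 (because 2 - 1 = 1)
  2. After algebraic simplification: x (because x * 1 = x)

Two passes, significant simplification. The order matters. Running algebraic simplification first would change nothing, because no identity matches x * (2 - 1), and the folding that follows would stop at x * 1. Running constant folding first creates the x * 1 that algebraic simplification can then remove.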
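A pipeline can be sketched as an ordered list of passes, where each pass consumes the previous pass's output. The toy `constant_fold` and `simplify` below implement only the rules this example needs, and `Pass` and `run_pipeline` are hypothetical names. Running both orders on `x * (2 - 1)` shows the difference:

```rust
#[derive(Debug, Clone, Copy, PartialEq)]
enum Op { Add, Sub, Mul }

#[derive(Debug, Clone, PartialEq)]
enum Expr {
    Int(i64),
    Var(String),
    Bin(Op, Box<Expr>, Box<Expr>),
}

/// A pass maps a tree to a new tree.
type Pass = fn(&Expr) -> Expr;

/// Pass: fold integer Add/Sub/Mul when both sides are constants.
fn constant_fold(e: &Expr) -> Expr {
    match e {
        Expr::Bin(op, l, r) => {
            let (l, r) = (constant_fold(l), constant_fold(r));
            if let (Expr::Int(a), Expr::Int(b)) = (&l, &r) {
                let v = match op {
                    Op::Add => a.checked_add(*b),
                    Op::Sub => a.checked_sub(*b),
                    Op::Mul => a.checked_mul(*b),
                };
                if let Some(v) = v {
                    return Expr::Int(v);
                }
            }
            Expr::Bin(*op, Box::new(l), Box::new(r))
        }
        other => other.clone(),
    }
}

/// Pass: algebraic simplification (only `x * 1` and `1 * x` here).
fn simplify(e: &Expr) -> Expr {
    match e {
        Expr::Bin(op, l, r) => {
            let (l, r) = (simplify(l), simplify(r));
            match (op, l, r) {
                (Op::Mul, x, Expr::Int(1)) | (Op::Mul, Expr::Int(1), x) => x,
                (op, l, r) => Expr::Bin(*op, Box::new(l), Box::new(r)),
            }
        }
        other => other.clone(),
    }
}

/// Runs each pass in order over the output of the previous one.
fn run_pipeline(e: &Expr, passes: &[Pass]) -> Expr {
    passes.iter().fold(e.clone(), |acc, pass| pass(&acc))
}
```

Folding first yields `x`. Simplifying first stops at `x * 1`.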

Why Bother?

You might wonder: “LLVM will optimize this anyway. Why do it ourselves?”

Good question. LLVM will do these optimizations. But:

  1. Learning: Implementing optimizations helps you understand how compilers work. These are the same techniques used in production compilers.

  2. Simplicity: Simpler AST means simpler code generation. Less can go wrong.

  3. Debug output: When you print the AST for debugging, optimized code is easier to read.

  4. Specialized optimizations: You might know things about your language that LLVM does not. Custom optimizations can exploit that knowledge.

  5. Compile time: Simpler AST means less work for LLVM, which can mean faster compilation.

Other Common Optimizations

Production compilers do many more optimizations:

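| Optimization | What it does |
|---|---|
| Dead code elimination | Remove code that can never run, or whose result is never used |
| Common subexpression elimination | Compute a repeated expression such as x * y once and reuse the result |
| Loop unrolling | Replicate the loop body so the loop runs fewer iterations |
| Inlining | Replace function calls with function bodies |
| Tail call optimization | Reuse the current stack frame for calls in tail position, which turns tail recursion into a loop |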

We leave these as exercises. The visitor pattern makes adding new optimizations straightforward.
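As a starting point for the dead code elimination exercise, here is a sketch of one simple case. Once constant folding has reduced an `if` condition to a literal, the branch that can never run can be dropped. The sketch uses a hypothetical toy AST in which `If` has expression branches. The book's `Expr::If` holds `Vec<Stmt>` branches, so a real version would more likely be an override of `visit_if`:

```rust
#[derive(Debug, Clone, PartialEq)]
enum Expr {
    Int(i64),
    Bool(bool),
    Var(String),
    If(Box<Expr>, Box<Expr>, Box<Expr>),
}

/// Dead-branch elimination: if a condition is a boolean literal,
/// keep only the branch that can actually run.
fn eliminate_dead_branches(e: &Expr) -> Expr {
    match e {
        Expr::If(cond, then_b, else_b) => {
            let cond = eliminate_dead_branches(cond);
            match cond {
                Expr::Bool(true) => eliminate_dead_branches(then_b),
                Expr::Bool(false) => eliminate_dead_branches(else_b),
                // Condition unknown at compile time: keep both branches,
                // but still clean up inside them.
                c => Expr::If(
                    Box::new(c),
                    Box::new(eliminate_dead_branches(then_b)),
                    Box::new(eliminate_dead_branches(else_b)),
                ),
            }
        }
        other => other.clone(),
    }
}
```

`if true { 1 } else { x }` becomes `1`. An `if` whose condition is still unknown keeps both branches, with any dead branches nested inside them removed.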

Using the Optimizations

Enable optimizations with the -O flag:

# Without optimization
rustup run nightly cargo run -- --ir examples/fibonacci.sl

# With optimization
rustup run nightly cargo run -- --ir -O examples/fibonacci.sl

Testing

rustup run nightly cargo test

In the next chapter, we look at what LLVM IR looks like before we start generating it.