AST Traversal Patterns

Recall from the previous section that JIT-compiling our add function took a lot of detailed, cumbersome code. Fortunately, there are some useful patterns for traversing complicated ASTs (and IRs):

  • Builder pattern
  • Visitor pattern (introduced in chapter 4)

Builder Pattern

Recall how we interpreted our AST by traversing it recursively and evaluating each node:

struct Eval;

impl Eval {
    pub fn new() -> Self {
        Self
    }
    pub fn eval(&self, node: &Node) -> i32 {
        match node {
            Node::Int(n) => *n,
            Node::UnaryExpr { op, child } => {
                let child = self.eval(child);
                match op {
                    Operator::Plus => child,
                    Operator::Minus => -child,
                }
            }
            Node::BinaryExpr { op, lhs, rhs } => {
                let lhs_ret = self.eval(lhs);
                let rhs_ret = self.eval(rhs);

                match op {
                    Operator::Plus => lhs_ret + rhs_ret,
                    Operator::Minus => lhs_ret - rhs_ret,
                }
            }
        }
    }
}

Filename: calculator/src/compiler/interpreter.rs
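To make the recursion concrete, here is a minimal, self-contained sketch of the same evaluation applied to a hand-built tree for `-(1 + 2) - 4`. The `Node` and `Operator` definitions are assumptions inferred from the match arms above (boxed children, two operators), and the real definitions may differ in detail. `eval` is written as a free function for brevity, and `sample_ast` is a hypothetical helper that exists only for this illustration.

```rust
// Assumed AST shape, inferred from the match arms in `Eval::eval`;
// the real definitions may differ in detail.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Operator {
    Plus,
    Minus,
}

#[derive(Debug, Clone, PartialEq)]
pub enum Node {
    Int(i32),
    UnaryExpr { op: Operator, child: Box<Node> },
    BinaryExpr { op: Operator, lhs: Box<Node>, rhs: Box<Node> },
}

// Same post-order recursion as `Eval::eval`: evaluate the children
// first, then combine their results according to the operator.
pub fn eval(node: &Node) -> i32 {
    match node {
        Node::Int(n) => *n,
        Node::UnaryExpr { op, child } => {
            let child = eval(child);
            match op {
                Operator::Plus => child,
                Operator::Minus => -child,
            }
        }
        Node::BinaryExpr { op, lhs, rhs } => {
            let (left, right) = (eval(lhs), eval(rhs));
            match op {
                Operator::Plus => left + right,
                Operator::Minus => left - right,
            }
        }
    }
}

// Hypothetical helper: a hand-built tree for `-(1 + 2) - 4`.
pub fn sample_ast() -> Node {
    Node::BinaryExpr {
        op: Operator::Minus,
        lhs: Box::new(Node::UnaryExpr {
            op: Operator::Minus,
            child: Box::new(Node::BinaryExpr {
                op: Operator::Plus,
                lhs: Box::new(Node::Int(1)),
                rhs: Box::new(Node::Int(2)),
            }),
        }),
        rhs: Box::new(Node::Int(4)),
    }
}
```

Calling `eval(&sample_ast())` first reduces the inner `1 + 2` to 3, negates it, and then subtracts 4, giving -7. Every node is visited exactly once, after its children.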

Instead of evaluating each node directly, we can take advantage of the inkwell Builder and recursively traverse our Calc AST, emitting LLVM IR as we go:

struct RecursiveBuilder<'a> {
    i32_type: IntType<'a>,
    builder: &'a Builder<'a>,
}

impl<'a> RecursiveBuilder<'a> {
    pub fn new(i32_type: IntType<'a>, builder: &'a Builder) -> Self {
        Self { i32_type, builder }
    }
    pub fn build(&self, ast: &Node) -> IntValue {
        match ast {
            Node::Int(n) => self.i32_type.const_int(*n as u64, true),
            Node::UnaryExpr { op, child } => {
                let child = self.build(child);
                match op {
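                    // Emit the negation through the builder so that it is
                    // also valid when `child` is not a constant.
                    Operator::Minus => self.builder.build_int_neg(child, "neg_temp").unwrap(),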
                    Operator::Plus => child,
                }
            }
            Node::BinaryExpr { op, lhs, rhs } => {
                let left = self.build(lhs);
                let right = self.build(rhs);

                match op {
                    Operator::Plus => self
                        .builder
                        .build_int_add(left, right, "plus_temp")
                        .unwrap(),
                    Operator::Minus => self
                        .builder
                        .build_int_sub(left, right, "minus_temp")
                        .unwrap(),
                }
            }
        }
    }
}
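The shape of this pattern is easier to see without LLVM in the way. The sketch below is a toy stand-in, not inkwell: `ToyBuilder`, `emit`, and the textual instruction format are all hypothetical, and the `Node` and `Operator` definitions are assumed from the match arms above. What it shares with `RecursiveBuilder` is the structure: each recursive call returns a handle to the value it produced, and the builder appends one instruction per operator node.

```rust
// Assumed AST shape, inferred from the match arms above.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Operator {
    Plus,
    Minus,
}

#[derive(Debug, Clone, PartialEq)]
pub enum Node {
    Int(i32),
    UnaryExpr { op: Operator, child: Box<Node> },
    BinaryExpr { op: Operator, lhs: Box<Node>, rhs: Box<Node> },
}

// Hypothetical toy builder: it only records textual instructions.
#[derive(Default)]
pub struct ToyBuilder {
    pub instructions: Vec<String>,
}

impl ToyBuilder {
    // Append `<value> = <opcode> <args>` and return the new value's name.
    // The instruction index is appended to `name` to keep names unique.
    fn emit(&mut self, opcode: &str, args: &[&str], name: &str) -> String {
        let value = format!("%{}{}", name, self.instructions.len());
        self.instructions
            .push(format!("{} = {} {}", value, opcode, args.join(", ")));
        value
    }
}

// Mirrors `RecursiveBuilder::build`: build the children first, then emit
// one instruction that combines the values they returned.
pub fn build(builder: &mut ToyBuilder, ast: &Node) -> String {
    match ast {
        // A literal needs no instruction; it is its own value.
        Node::Int(n) => n.to_string(),
        Node::UnaryExpr { op, child } => {
            let child = build(builder, child);
            match op {
                Operator::Plus => child,
                // Negation expressed as `0 - child`.
                Operator::Minus => builder.emit("sub", &["0", child.as_str()], "neg_temp"),
            }
        }
        Node::BinaryExpr { op, lhs, rhs } => {
            let left = build(builder, lhs);
            let right = build(builder, rhs);
            match op {
                Operator::Plus => {
                    builder.emit("add", &[left.as_str(), right.as_str()], "plus_temp")
                }
                Operator::Minus => {
                    builder.emit("sub", &[left.as_str(), right.as_str()], "minus_temp")
                }
            }
        }
    }
}
```

For `1 - (2 + 3)`, the toy builder records `%plus_temp0 = add 2, 3` followed by `%minus_temp1 = sub 1, %plus_temp0`, and `build` returns `%minus_temp1`. A real LLVM builder may fold constant operands instead of emitting instructions, so the IR printed by the JIT can look different from this trace.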

Similar to our addition example, we can JIT-compile the builder output:

pub struct Jit;

impl Compile for Jit {
    type Output = Result<i32>;

    fn from_ast(ast: Vec<Node>) -> Self::Output {
        let context = Context::create();
        let module = context.create_module("calculator");

        let builder = context.create_builder();

        let execution_engine = module
            .create_jit_execution_engine(OptimizationLevel::None)
            .unwrap();

        let i32_type = context.i32_type();
        let fn_type = i32_type.fn_type(&[], false);

        let function = module.add_function("jit", fn_type, None);
        let basic_block = context.append_basic_block(function, "entry");

        builder.position_at_end(basic_block);

        for node in ast {
            let recursive_builder = RecursiveBuilder::new(i32_type, &builder);
            let return_value = recursive_builder.build(&node);
            let _ = builder.build_return(Some(&return_value));
        }
        println!(
            "Generated LLVM IR: {}",
            function.print_to_string().to_string()
        );

        unsafe {
            let jit_function: JitFunction<JitFunc> = execution_engine.get_function("jit").unwrap();

            Ok(jit_function.call())
        }
    }
}

Filename: calculator/src/compiler/jit.rs

Finally, we can test it:

assert_eq!(Jit::from_source("1 + 2").unwrap(), 3);

Run such tests locally with:

cargo test jit --tests