AST Traversal Patterns
Recall from the previous section that JIT-compiling our add function took a lot of detailed, cumbersome code. Fortunately, there are useful patterns for traversing complicated ASTs (and IRs):
- Builder pattern
- Visitor pattern (introduced in chapter 4)
Builder Pattern
Recall how we interpreted our AST by recursively traversing it and evaluating each node:
struct Eval;

impl Eval {
    pub fn new() -> Self {
        Self
    }

    pub fn eval(&self, node: &Node) -> i32 {
        match node {
            // A literal evaluates to itself.
            Node::Int(n) => *n,
            // Evaluate the operand first, then apply the unary operator.
            Node::UnaryExpr { op, child } => {
                let child = self.eval(child);
                match op {
                    Operator::Plus => child,
                    Operator::Minus => -child,
                }
            }
            // Evaluate both operands, then combine them.
            Node::BinaryExpr { op, lhs, rhs } => {
                let lhs_ret = self.eval(lhs);
                let rhs_ret = self.eval(rhs);
                match op {
                    Operator::Plus => lhs_ret + rhs_ret,
                    Operator::Minus => lhs_ret - rhs_ret,
                }
            }
        }
    }
}
Filename: calculator/src/compiler/interpreter.rs
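The traversal is post-order: a node's operands are evaluated before the operator that combines them, which is the same order a code generator needs. The minimal, self-contained sketch below makes that order visible. The `Node`/`Operator` definitions are assumed stand-ins with the same shape as the calculator's AST, and `eval_traced` and `sample_ast` are hypothetical helpers that are not part of the crate.

```rust
// Stand-in AST types (assumption): same shape as the calculator's Node/Operator.
#[derive(Debug, Clone, Copy)]
pub enum Operator {
    Plus,
    Minus,
}

#[derive(Debug)]
pub enum Node {
    Int(i32),
    UnaryExpr { op: Operator, child: Box<Node> },
    BinaryExpr { op: Operator, lhs: Box<Node>, rhs: Box<Node> },
}

/// Hypothetical helper: evaluates like `Eval::eval`, but records each node's
/// value in `trace` right after computing it (i.e. in post-order).
pub fn eval_traced(node: &Node, trace: &mut Vec<i32>) -> i32 {
    let value = match node {
        Node::Int(n) => *n,
        Node::UnaryExpr { op, child } => {
            let child = eval_traced(child, trace);
            match op {
                Operator::Plus => child,
                Operator::Minus => -child,
            }
        }
        Node::BinaryExpr { op, lhs, rhs } => {
            let lhs = eval_traced(lhs, trace);
            let rhs = eval_traced(rhs, trace);
            match op {
                Operator::Plus => lhs + rhs,
                Operator::Minus => lhs - rhs,
            }
        }
    };
    trace.push(value);
    value
}

/// Hypothetical helper: the AST for `1 + -2`, built by hand instead of parsed.
pub fn sample_ast() -> Node {
    Node::BinaryExpr {
        op: Operator::Plus,
        lhs: Box::new(Node::Int(1)),
        rhs: Box::new(Node::UnaryExpr {
            op: Operator::Minus,
            child: Box::new(Node::Int(2)),
        }),
    }
}
```

Evaluating `sample_ast()` returns -1 and leaves the trace `[1, 2, -2, -1]`. Each value appears only after the values it depends on.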
Instead of evaluating each node directly, we can use the inkwell Builder to emit LLVM IR while recursively traversing our Calc AST:
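use inkwell::builder::Builder;
use inkwell::types::IntType;
use inkwell::values::IntValue;

// Node and Operator are the calculator's AST types from earlier sections.
struct RecursiveBuilder<'a> {
    i32_type: IntType<'a>,
    builder: &'a Builder<'a>,
}

impl<'a> RecursiveBuilder<'a> {
    pub fn new(i32_type: IntType<'a>, builder: &'a Builder<'a>) -> Self {
        Self { i32_type, builder }
    }

    /// Emits IR for `ast` at the builder's current position and returns the
    /// value holding its result. Children are built before their parent.
    pub fn build(&self, ast: &Node) -> IntValue<'a> {
        match ast {
            // Literals become i32 constants.
            Node::Int(n) => self.i32_type.const_int(*n as u64, true),
            Node::UnaryExpr { op, child } => {
                let child = self.build(child);
                match op {
                    // build_int_neg emits `sub 0, x`, which is valid whether or not
                    // `child` is a constant (const_neg only accepts constants).
                    Operator::Minus => self
                        .builder
                        .build_int_neg(child, "negate_temp")
                        .unwrap(),
                    Operator::Plus => child,
                }
            }
            Node::BinaryExpr { op, lhs, rhs } => {
                let left = self.build(lhs);
                let right = self.build(rhs);
                match op {
                    Operator::Plus => self
                        .builder
                        .build_int_add(left, right, "plus_temp")
                        .unwrap(),
                    Operator::Minus => self
                        .builder
                        .build_int_sub(left, right, "minus_temp")
                        .unwrap(),
                }
            }
        }
    }
}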
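To see what this pattern does without pulling in LLVM, here is a hedged, self-contained analogue. `ToyBuilder` is hypothetical. It mimics the shape of the inkwell Builder by appending one LLVM-like textual instruction per `build_*` call and returning a name for the result. The `Node`/`Operator` definitions are assumed stand-ins for the calculator's AST.

```rust
// Stand-in AST types (assumption): same shape as the calculator's Node/Operator.
#[derive(Debug, Clone, Copy)]
pub enum Operator {
    Plus,
    Minus,
}

#[derive(Debug)]
pub enum Node {
    Int(i32),
    UnaryExpr { op: Operator, child: Box<Node> },
    BinaryExpr { op: Operator, lhs: Box<Node>, rhs: Box<Node> },
}

/// Hypothetical stand-in for an IR builder: every `build_*` call appends one
/// LLVM-like instruction (as text) and returns the name of the value it defines.
#[derive(Default)]
pub struct ToyBuilder {
    pub instructions: Vec<String>,
    next_tmp: usize,
}

impl ToyBuilder {
    /// Returns a fresh SSA-style name such as `%plus_temp0`.
    fn fresh(&mut self, hint: &str) -> String {
        let name = format!("%{}{}", hint, self.next_tmp);
        self.next_tmp += 1;
        name
    }

    fn build_binop(&mut self, opcode: &str, lhs: &str, rhs: &str, hint: &str) -> String {
        let dest = self.fresh(hint);
        self.instructions
            .push(format!("{} = {} i32 {}, {}", dest, opcode, lhs, rhs));
        dest
    }

    /// Same shape as `RecursiveBuilder::build`: build the children first, then
    /// emit this node's instruction using their results.
    pub fn build(&mut self, node: &Node) -> String {
        match node {
            // Constants need no instruction, much like `const_int`.
            Node::Int(n) => n.to_string(),
            Node::UnaryExpr { op, child } => {
                let child = self.build(child);
                match op {
                    Operator::Plus => child,
                    // LLVM IR has no integer `neg` instruction; negation is `sub 0, x`.
                    Operator::Minus => self.build_binop("sub", "0", &child, "negate_temp"),
                }
            }
            Node::BinaryExpr { op, lhs, rhs } => {
                let left = self.build(lhs);
                let right = self.build(rhs);
                match op {
                    Operator::Plus => self.build_binop("add", &left, &right, "plus_temp"),
                    Operator::Minus => self.build_binop("sub", &left, &right, "minus_temp"),
                }
            }
        }
    }
}
```

For `1 + -2`, this emits `%negate_temp0 = sub i32 0, 2` followed by `%plus_temp1 = add i32 1, %negate_temp0`. One difference from the real builder is that LLVM's IR builder folds operations whose operands are all constants. As a result, the IR printed for a purely constant expression such as `1 + 2` will typically be a single `ret` of a constant rather than an `add` instruction.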
Similar to our addition example, we can then JIT-compile the IR the builder produces and call it:
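use inkwell::context::Context;
use inkwell::execution_engine::JitFunction;
use inkwell::OptimizationLevel;

// Compile, Node and Result come from the calculator crate (earlier sections).
// Signature of the JIT-compiled function: no arguments, returns an i32.
type JitFunc = unsafe extern "C" fn() -> i32;

pub struct Jit;

impl Compile for Jit {
    type Output = Result<i32>;

    fn from_ast(ast: Vec<Node>) -> Self::Output {
        let context = Context::create();
        let module = context.create_module("calculator");
        let builder = context.create_builder();
        let execution_engine = module
            .create_jit_execution_engine(OptimizationLevel::None)
            .unwrap();

        // Declare `i32 jit()` and start emitting into its entry block.
        let i32_type = context.i32_type();
        let fn_type = i32_type.fn_type(&[], false);
        let function = module.add_function("jit", fn_type, None);
        let basic_block = context.append_basic_block(function, "entry");
        builder.position_at_end(basic_block);

        // A single expression is expected; its value becomes the return value.
        // (A second node would add a second `ret` to the same block: invalid IR.)
        for node in ast {
            let recursive_builder = RecursiveBuilder::new(i32_type, &builder);
            let return_value = recursive_builder.build(&node);
            builder.build_return(Some(&return_value)).unwrap();
        }

        println!(
            "Generated LLVM IR: {}",
            function.print_to_string().to_string()
        );

        // Looking up and calling JIT-compiled code is unsafe: the signature
        // must match JitFunc.
        unsafe {
            let jit_function: JitFunction<JitFunc> = execution_engine.get_function("jit").unwrap();
            Ok(jit_function.call())
        }
    }
}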
Filename: calculator/src/compiler/jit.rs
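Unlike the interpreter, the JIT traverses the AST once to produce code and then calls that code like an ordinary function. The sketch below is a hedged, LLVM-free analogue of that split and is not what inkwell generates. It assumes a toy stack machine can stand in for machine code, and the `Instr` type and the `compile`/`run` functions are hypothetical. The `Node`/`Operator` definitions are assumed stand-ins for the calculator's AST.

```rust
// Stand-in AST types (assumption): same shape as the calculator's Node/Operator.
#[derive(Debug, Clone, Copy)]
pub enum Operator {
    Plus,
    Minus,
}

#[derive(Debug)]
pub enum Node {
    Int(i32),
    UnaryExpr { op: Operator, child: Box<Node> },
    BinaryExpr { op: Operator, lhs: Box<Node>, rhs: Box<Node> },
}

/// Hypothetical stack-machine instructions standing in for machine code.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Instr {
    Push(i32),
    Neg,
    Add,
    Sub,
}

/// "Compile" phase: a post-order traversal that emits operands before operators.
pub fn compile(node: &Node, code: &mut Vec<Instr>) {
    match node {
        Node::Int(n) => code.push(Instr::Push(*n)),
        Node::UnaryExpr { op, child } => {
            compile(child, code);
            // Unary plus is a no-op, so it emits nothing.
            if let Operator::Minus = op {
                code.push(Instr::Neg);
            }
        }
        Node::BinaryExpr { op, lhs, rhs } => {
            compile(lhs, code);
            compile(rhs, code);
            code.push(match op {
                Operator::Plus => Instr::Add,
                Operator::Minus => Instr::Sub,
            });
        }
    }
}

/// Pops the right operand (pushed last) and then the left one.
fn pop_pair(stack: &mut Vec<i32>) -> (i32, i32) {
    let rhs = stack.pop().expect("stack underflow");
    let lhs = stack.pop().expect("stack underflow");
    (lhs, rhs)
}

/// "Call" phase: runs the compiled code without touching the AST again.
pub fn run(code: &[Instr]) -> i32 {
    let mut stack: Vec<i32> = Vec::new();
    for instr in code {
        match instr {
            Instr::Push(n) => stack.push(*n),
            Instr::Neg => {
                let v = stack.pop().expect("stack underflow");
                stack.push(-v);
            }
            Instr::Add => {
                let (lhs, rhs) = pop_pair(&mut stack);
                stack.push(lhs + rhs);
            }
            Instr::Sub => {
                let (lhs, rhs) = pop_pair(&mut stack);
                stack.push(lhs - rhs);
            }
        }
    }
    stack.pop().expect("empty program")
}
```

For `1 - -2`, `compile` emits `[Push(1), Push(2), Neg, Sub]` and `run` returns 3. `run` can be called again on the same code without re-traversing the AST, much as `jit_function` can be called repeatedly once compiled.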
Finally, we can test it:
assert_eq!(Jit::from_source("1 + 2").unwrap(), 3);
Run such tests locally with:
cargo test jit --tests