Type Inference
Type inference is like filling in a crossword puzzle. Some squares have letters (explicit annotations), others are blank (Unknown). You use the constraints (“this must be 5 letters”, “it crosses with CAT”) to fill in the blanks. Type inference uses constraints like “this is added to an int, so it must be int” to fill in Unknown types.
In the previous chapter, we saw that the parser creates an AST with many Type::Unknown values. The type checker’s job is to figure out what those unknown types should be. This process is called type inference.
What Kind of Type Inference?
There are different approaches to type inference:
| Approach | Used By | Polymorphism | Complexity |
|---|---|---|---|
| Hindley-Milner | Haskell, ML, OCaml | Full parametric | High |
| Local Type Inference | TypeScript, Go, Rust, Swift | Limited | Low |
| Bidirectional Type Checking | Scala, Agda | Configurable | Medium |
We use local type inference (also called “flow-based” inference). This is simpler than Hindley-Milner but covers the common cases. The key difference:
- Hindley-Milner: Can infer polymorphic types like `fn identity<T>(x: T) -> T` without any annotations
- Local inference: Requires type annotations at function boundaries; infers types within function bodies
Our approach is similar to what TypeScript, Go, and Rust use. It is practical, easy to understand, and sufficient for our language.
The Algorithm in Pseudocode
Before diving into Rust code, here is the algorithm in pseudocode:
```
ALGORITHM: Local Type Inference
INPUT:  AST with some types marked as Unknown
OUTPUT: AST with all types filled in, or an error

1. COLLECT SIGNATURES:
   for each function definition in program:
       record (function_name -> function_type) in environment

2. TYPECHECK EACH STATEMENT:
   for each statement in program:
       typecheck_statement(statement, environment)

FUNCTION typecheck_statement(stmt, env):
    case stmt of:
    | Assignment(name, value):
        inferred_type = typecheck_expr(value, env)
        if explicit_annotation exists:
            check annotation matches inferred_type
        add (name -> inferred_type) to env
    | Function(name, params, return_type, body):
        local_env = copy of env
        for each (param_name, param_type) in params:
            add (param_name -> param_type) to local_env
        for each stmt in body:
            typecheck_statement(stmt, local_env)
    | Return(expr):
        actual_type = typecheck_expr(expr, env)
        check actual_type matches the enclosing function's return_type

FUNCTION typecheck_expr(expr, env) -> Type:
    case expr of:
    | Int(n):    return Int
    | Bool(b):   return Bool
    | Var(name): return lookup(name, env)
    | Binary(op, left, right):
        left_type  = typecheck_expr(left, env)
        right_type = typecheck_expr(right, env)
        return apply_op_rules(op, left_type, right_type)
    | Call(name, args):
        func_type = lookup(name, env)
        check len(args) == len(func_type.params)
        for each (arg, expected_type) in zip(args, func_type.params):
            actual_type = typecheck_expr(arg, env)
            check actual_type matches expected_type
        return func_type.return_type
    | If(cond, then, else):
        check typecheck_expr(cond, env) == Bool
        then_type = typecheck_block(then, env)   # type of the block's last statement
        else_type = typecheck_block(else, env)
        return unify(then_type, else_type)
```
The key insight: types flow forward from known sources (literals, parameters) through operations into variables.
The Core Insight
Here is the key idea: types flow through expressions. If we know the type of the inputs, we can figure out the type of the output.
Consider x = 1 + 2. How does the compiler know x is an int?
- `1` is an integer literal → type is `Int`
- `2` is an integer literal → type is `Int`
- `+` with two `Int` operands → produces `Int`
- We are assigning an `Int` to `x` → `x` must be `Int`
The type “flows” from the literals, through the operator, into the variable. No explicit annotation needed.
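To make that flow concrete, here is a minimal sketch of bottom-up inference over a tiny expression tree. The `Type` and `Expr` enums and the `infer` function are hypothetical stand-ins, far smaller than Secondlang's real AST: literals are the known sources of type information, and each operator computes its result type from its operands' types.

```rust
/// Types known to this toy checker.
#[derive(Debug, Clone, PartialEq)]
enum Type {
    Int,
    Bool,
}

/// A hypothetical, minimal expression tree (not Secondlang's real AST).
enum Expr {
    Int(i64),
    Bool(bool),
    Add(Box<Expr>, Box<Expr>),
    Lt(Box<Expr>, Box<Expr>),
}

/// Infer an expression's type by first inferring its children (bottom-up).
fn infer(expr: &Expr) -> Result<Type, String> {
    match expr {
        // Literals: the type is known immediately.
        Expr::Int(_) => Ok(Type::Int),
        Expr::Bool(_) => Ok(Type::Bool),
        // `+` needs two Int operands and produces Int.
        Expr::Add(l, r) => match (infer(l)?, infer(r)?) {
            (Type::Int, Type::Int) => Ok(Type::Int),
            (lt, rt) => Err(format!("cannot add {:?} and {:?}", lt, rt)),
        },
        // `<` needs two Int operands and produces Bool.
        Expr::Lt(l, r) => match (infer(l)?, infer(r)?) {
            (Type::Int, Type::Int) => Ok(Type::Bool),
            (lt, rt) => Err(format!("cannot compare {:?} and {:?}", lt, rt)),
        },
    }
}
```

For `1 + 2`, `infer` returns `Ok(Type::Int)`, which is exactly the type an assignment like `x = 1 + 2` would record for `x`.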
Step-by-Step Example
Let us trace through this code in detail:
```
x = 42
y = x * 2 + 10
is_big = y > 50
```
Step 1: Parse (types are Unknown)
After parsing, the AST looks like this (simplified):
```
Assignment { name: "x", value: Int(42), ty: Unknown }
Assignment { name: "y", value: Binary(Var("x") * Int(2) + Int(10)), ty: Unknown }
Assignment { name: "is_big", value: Binary(Var("y") > Int(50)), ty: Unknown }
```
Every expression has ty: Unknown. We do not know the types yet.
Step 2: Type check first assignment
For x = 42:
- Check the value `42`:
  - It is an `Int` literal
  - Set its type: `Int(42).ty = Int`
- Infer the variable type:
  - No explicit annotation, so we use the value's type
  - `x` has type `Int`
  - Add to environment: `{ x: Int }`
Step 3: Type check second assignment
For y = x * 2 + 10:
- Check `x * 2`:
  - Look up `x` in environment → `Int`
  - `2` is an `Int` literal
  - `*` with `Int * Int` → produces `Int`
  - Set type: `(x * 2).ty = Int`
- Check `(x * 2) + 10`:
  - Left side `(x * 2)` has type `Int` (from the step above)
  - Right side `10` is an `Int` literal
  - `+` with `Int + Int` → produces `Int`
  - Set type: `((x * 2) + 10).ty = Int`
- Infer the variable type:
  - Value has type `Int`
  - `y` has type `Int`
  - Add to environment: `{ x: Int, y: Int }`
Step 4: Type check third assignment
For is_big = y > 50:
- Check `y > 50`:
  - Look up `y` in environment → `Int`
  - `50` is an `Int` literal
  - `>` with `Int > Int` → produces `Bool` (comparisons return booleans)
  - Set type: `(y > 50).ty = Bool`
- Infer the variable type:
  - Value has type `Bool`
  - `is_big` has type `Bool`
  - Add to environment: `{ x: Int, y: Int, is_big: Bool }`
Final Result
All Unknown types are now resolved:
```
Assignment { name: "x", value: Int(42), ty: Int }
Assignment { name: "y", value: Binary(...), ty: Int }
Assignment { name: "is_big", value: Binary(...), ty: Bool }
```
The compiler inferred all the types without us writing a single type annotation.
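The whole trace can be replayed in a few lines of code. This is a sketch under simplifying assumptions: `Expr`, `infer`, `require_ints` and `check_assignments` are hypothetical names, and a program is modeled as a list of `name = value` pairs checked in order against a `HashMap` environment.

```rust
use std::collections::HashMap;

#[derive(Debug, Clone, PartialEq)]
enum Type {
    Int,
    Bool,
}

/// Hypothetical mini-AST covering only what the example program uses.
enum Expr {
    Int(i64),
    Var(String),
    Mul(Box<Expr>, Box<Expr>),
    Add(Box<Expr>, Box<Expr>),
    Gt(Box<Expr>, Box<Expr>),
}

type TypeEnv = HashMap<String, Type>;

/// Both operands of an arithmetic or comparison operator must be Int.
fn require_ints(l: Type, r: Type) -> Result<(), String> {
    if l == Type::Int && r == Type::Int {
        Ok(())
    } else {
        Err(format!("expected Int operands, got {:?} and {:?}", l, r))
    }
}

fn infer(expr: &Expr, env: &TypeEnv) -> Result<Type, String> {
    match expr {
        Expr::Int(_) => Ok(Type::Int),
        // Variables get whatever type an earlier assignment recorded.
        Expr::Var(name) => env
            .get(name)
            .cloned()
            .ok_or_else(|| format!("undefined variable: {}", name)),
        Expr::Mul(l, r) | Expr::Add(l, r) => {
            require_ints(infer(l, env)?, infer(r, env)?)?;
            Ok(Type::Int)
        }
        Expr::Gt(l, r) => {
            require_ints(infer(l, env)?, infer(r, env)?)?;
            Ok(Type::Bool)
        }
    }
}

/// Check `name = value` statements in order, recording each inferred type.
fn check_assignments(stmts: &[(&str, Expr)]) -> Result<TypeEnv, String> {
    let mut env = TypeEnv::new();
    for (name, value) in stmts {
        let ty = infer(value, &env)?;
        env.insert(name.to_string(), ty);
    }
    Ok(env)
}
```

Because statements are checked in order, `y` can use the type already recorded for `x`, and `is_big` can use the type recorded for `y`.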
Typing Rules
The type checker applies these typing rules:
| Expression | Rule | Result Type |
|---|---|---|
| `42` | Integer literals are Int | Int |
| `true`, `false` | Boolean literals are Bool | Bool |
| `x` (variable) | Look up in type environment | env[x] |
| `a + b`, `a * b`, etc. | Both operands must be Int | Int |
| `a < b`, `a > b`, `a <= b`, `a >= b` | Both operands must be Int | Bool |
| `a == b`, `a != b` | Operand types must unify | Bool |
| `!a` | Operand must be Bool | Bool |
| `-a` | Operand must be Int | Int |
| `f(args...)` | Args must match function parameter types | Function’s return type |
| `if (c) { a } else { b }` | c must be Bool; a and b must unify | Unified type of a and b |
These rules are applied recursively, bottom-up through the expression tree.
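The operator rows of the table can be written as a pure function from an operator and its operand types to a result type, which is the role `apply_op_rules` plays in the pseudocode. The sketch below uses hypothetical names (`BinOp`, `UnOp`, `binary_rule`, `unary_rule`) and has no `Unknown` type, so equality here simply requires both sides to have the same type.

```rust
#[derive(Debug, Clone, PartialEq)]
enum Type {
    Int,
    Bool,
}

#[derive(Debug, Clone, Copy)]
enum BinOp { Add, Sub, Mul, Div, Mod, Lt, Gt, Le, Ge, Eq, Ne }

#[derive(Debug, Clone, Copy)]
enum UnOp { Neg, Not }

/// Result type of `left op right`, following the typing-rules table.
fn binary_rule(op: BinOp, left: &Type, right: &Type) -> Result<Type, String> {
    match op {
        // Arithmetic: Int op Int -> Int
        BinOp::Add | BinOp::Sub | BinOp::Mul | BinOp::Div | BinOp::Mod => {
            if *left == Type::Int && *right == Type::Int {
                Ok(Type::Int)
            } else {
                Err(format!("{:?} requires Int operands, got {:?} and {:?}", op, left, right))
            }
        }
        // Ordering comparisons: Int op Int -> Bool
        BinOp::Lt | BinOp::Gt | BinOp::Le | BinOp::Ge => {
            if *left == Type::Int && *right == Type::Int {
                Ok(Type::Bool)
            } else {
                Err(format!("{:?} requires Int operands, got {:?} and {:?}", op, left, right))
            }
        }
        // Equality: both sides must have the same type; the result is Bool
        BinOp::Eq | BinOp::Ne => {
            if left == right {
                Ok(Type::Bool)
            } else {
                Err(format!("cannot compare {:?} with {:?}", left, right))
            }
        }
    }
}

/// Result type of a unary operator applied to an operand of type `operand`.
fn unary_rule(op: UnOp, operand: &Type) -> Result<Type, String> {
    match (op, operand) {
        (UnOp::Neg, Type::Int) => Ok(Type::Int),
        (UnOp::Not, Type::Bool) => Ok(Type::Bool),
        _ => Err(format!("{:?} cannot be applied to {:?}", op, operand)),
    }
}
```

Keeping the rules separate from the tree walk means the recursive checker only has to compute child types and then ask these functions for the parent's type.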
Type Unification
Unification is the process of checking if two types are compatible and finding a common type. This is a key operation in type inference.
Here is the pseudocode:
```
FUNCTION unify(type1, type2) -> Type or Error:
    if type1 == type2:
        return type1                  # Same types match
    if type1 == Unknown:
        return type2                  # Unknown takes the other type
    if type2 == Unknown:
        return type1                  # Unknown takes the other type
    if type1 is Function and type2 is Function
            and both have the same number of parameters:
        unify each pair of parameter types
        unify the return types
        return the unified function type
    return Error("Cannot unify type1 with type2")
```
Our implementation:
```rust
/// Try to unify this type with another type.
/// Returns the unified type if successful, or an error message.
pub fn unify(&self, other: &Type) -> Result<Type, String> {
    match (self, other) {
        // Same types unify
        (Type::Int, Type::Int) => Ok(Type::Int),
        (Type::Bool, Type::Bool) => Ok(Type::Bool),
        (Type::Unit, Type::Unit) => Ok(Type::Unit),
        // Unknown can unify with anything
        (Type::Unknown, t) | (t, Type::Unknown) => Ok(t.clone()),
        // Function types must have the same arity and compatible signatures
        (
            Type::Function {
                params: p1,
                ret: r1,
            },
            Type::Function {
                params: p2,
                ret: r2,
            },
        ) if p1.len() == p2.len() => {
            let params: Result<Vec<_>, _> =
                p1.iter().zip(p2.iter()).map(|(a, b)| a.unify(b)).collect();
            let ret = r1.unify(r2)?;
            Ok(Type::Function {
                params: params?,
                ret: Box::new(ret),
            })
        }
        // Type mismatch: `self` is the expected type, `other` the actual one.
        // Uses the Display impl, like the other error messages in the checker.
        _ => Err(format!(
            "Type mismatch: expected {}, got {}",
            self, other
        )),
    }
}
```
Let us understand each case:
| Unify | Result | Why |
|---|---|---|
| `Int.unify(Int)` | `Ok(Int)` | Same types match |
| `Bool.unify(Bool)` | `Ok(Bool)` | Same types match |
| `Unknown.unify(Int)` | `Ok(Int)` | Unknown takes on the concrete type |
| `Int.unify(Unknown)` | `Ok(Int)` | Unknown takes on the concrete type |
| `Int.unify(Bool)` | `Err` | Incompatible types cannot unify |
The Unknown case is the heart of type inference. When we unify Unknown with a concrete type, we learn what the unknown type should be.
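Here is a small sketch of that "learning" step. It re-implements a cut-down `unify` as a free function and adds a hypothetical `refine` helper (not part of the chapter's code) that unifies a variable's recorded type with the type some use of it requires, then writes the result back into the environment.

```rust
use std::collections::HashMap;

#[derive(Debug, Clone, PartialEq)]
enum Type {
    Int,
    Bool,
    Unknown,
}

/// Cut-down unify: Unknown adopts the other type; otherwise types must be equal.
fn unify(a: &Type, b: &Type) -> Result<Type, String> {
    match (a, b) {
        (Type::Unknown, t) | (t, Type::Unknown) => Ok(t.clone()),
        (x, y) if x == y => Ok(x.clone()),
        _ => Err(format!("Type mismatch: expected {:?}, got {:?}", a, b)),
    }
}

/// Hypothetical helper: refine a variable's recorded type using the type
/// required at one of its uses, and store what we learned.
fn refine(env: &mut HashMap<String, Type>, name: &str, required: &Type) -> Result<Type, String> {
    let current = env
        .get(name)
        .cloned()
        .ok_or_else(|| format!("Undefined variable: {}", name))?;
    let unified = unify(&current, required)?;
    env.insert(name.to_string(), unified.clone());
    Ok(unified)
}
```

Once an `Unknown` binding has been refined to `Int`, any later use that requires `Bool` is reported as a mismatch instead of silently succeeding.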
The Type Environment
The type environment (also called symbol table or context) maps names to types:
```rust
use std::collections::HashMap;

type TypeEnv = HashMap<String, Type>;
```
The environment is:
- Extended when we declare a variable or enter a function (adding new bindings)
- Queried when we reference a variable (looking up its type)
- Scoped - inner scopes can shadow outer bindings
This is the same concept as the runtime environment in Firstlang’s interpreter, but storing types instead of values.
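The checker code in this chapter gets scoping by cloning the whole map when it enters a block (`env.clone()`). A common alternative, sketched here with a hypothetical `ScopedEnv` type, keeps a stack of maps: extending writes to the innermost scope, querying searches from the innermost scope outward, and leaving a scope simply drops its map.

```rust
use std::collections::HashMap;

#[derive(Debug, Clone, PartialEq)]
enum Type {
    Int,
    Bool,
}

/// Hypothetical scoped type environment: a stack of scopes, innermost last.
struct ScopedEnv {
    scopes: Vec<HashMap<String, Type>>,
}

impl ScopedEnv {
    /// Start with a single (global) scope.
    fn new() -> Self {
        ScopedEnv { scopes: vec![HashMap::new()] }
    }

    /// Enter a new inner scope, e.g. a function body or block.
    fn push_scope(&mut self) {
        self.scopes.push(HashMap::new());
    }

    /// Leave the innermost scope, discarding its bindings (the global scope is kept).
    fn pop_scope(&mut self) {
        if self.scopes.len() > 1 {
            self.scopes.pop();
        }
    }

    /// Extend: bind a name in the innermost scope, possibly shadowing an outer binding.
    fn insert(&mut self, name: &str, ty: Type) {
        self.scopes
            .last_mut()
            .expect("there is always at least one scope")
            .insert(name.to_string(), ty);
    }

    /// Query: search from the innermost scope outward.
    fn get(&self, name: &str) -> Option<&Type> {
        self.scopes.iter().rev().find_map(|scope| scope.get(name))
    }
}
```

The trade-off: cloning is simpler and matches how the interpreter's runtime environment often works, while the stack avoids copying every outer binding each time a block is entered.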
Function Type Inference
Functions are trickier because we need to handle:
- Parameters (types come from annotations)
- Local variables (types are inferred)
- Return value (must match declared return type)
Consider:
```
def compute(a: int, b: int) -> int {
    temp = a + b          # What type is temp?
    doubled = temp * 2    # What type is doubled?
    return doubled + 1
}
```
The type checker:
- Adds parameters to the environment: `{ a: Int, b: Int }`
- Checks `temp = a + b`:
  - `a` is `Int`, `b` is `Int`
  - `a + b` is `Int`
  - `temp` is `Int`
  - Environment: `{ a: Int, b: Int, temp: Int }`
- Checks `doubled = temp * 2`:
  - `temp` is `Int`, `2` is `Int`
  - `temp * 2` is `Int`
  - `doubled` is `Int`
  - Environment: `{ a: Int, b: Int, temp: Int, doubled: Int }`
- Checks `return doubled + 1`:
  - `doubled` is `Int`, `1` is `Int`
  - `doubled + 1` is `Int`
  - Declared return type is `Int`, so it matches
All types are inferred from the parameter types flowing through the expressions.
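The steps above can be sketched as one function that seeds a local environment from the parameters, infers each local in order, and compares every `return` against the declared type. `Expr`, `Stmt`, `infer` and `check_function` are hypothetical names for a mini-AST just large enough to express `compute`; the function returns the final local environment so the inferred types can be inspected.

```rust
use std::collections::HashMap;

#[derive(Debug, Clone, PartialEq)]
enum Type {
    Int,
    Bool,
}

/// Hypothetical mini-AST: just enough to express `compute`.
enum Expr {
    Int(i64),
    Var(String),
    Add(Box<Expr>, Box<Expr>),
    Mul(Box<Expr>, Box<Expr>),
}

enum Stmt {
    Assign(String, Expr),
    Return(Expr),
}

fn infer(expr: &Expr, env: &HashMap<String, Type>) -> Result<Type, String> {
    match expr {
        Expr::Int(_) => Ok(Type::Int),
        Expr::Var(n) => env
            .get(n)
            .cloned()
            .ok_or_else(|| format!("Undefined variable: {}", n)),
        Expr::Add(l, r) | Expr::Mul(l, r) => {
            if infer(l, env)? == Type::Int && infer(r, env)? == Type::Int {
                Ok(Type::Int)
            } else {
                Err("Arithmetic operation requires int operands".to_string())
            }
        }
    }
}

/// Parameters seed the local environment, locals are inferred in order,
/// and each `return` must match the declared return type.
fn check_function(
    params: &[(&str, Type)],
    ret: &Type,
    body: &[Stmt],
) -> Result<HashMap<String, Type>, String> {
    let mut env = HashMap::new();
    for (name, ty) in params {
        env.insert(name.to_string(), ty.clone());
    }
    for stmt in body {
        match stmt {
            Stmt::Assign(name, value) => {
                let ty = infer(value, &env)?;
                env.insert(name.clone(), ty);
            }
            Stmt::Return(value) => {
                let ty = infer(value, &env)?;
                if &ty != ret {
                    return Err(format!("return type mismatch: declared {:?}, got {:?}", ret, ty));
                }
            }
        }
    }
    Ok(env)
}
```

Declaring `compute` as returning `bool` instead would make the final `return doubled + 1` fail, because the body still infers `Int`.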
The Two-Pass Algorithm
The type checker uses two passes:
```rust
/// Type check and infer types for a program
pub fn typecheck(program: &mut Program) -> Result<(), String> {
    let mut env = TypeEnv::new();

    // First pass: collect function signatures
    for stmt in program.iter() {
        if let Stmt::Function {
            name,
            params,
            return_type,
            ..
        } = stmt
        {
            let param_types: Vec<Type> = params.iter().map(|(_, t)| t.clone()).collect();
            let func_type = Type::Function {
                params: param_types,
                ret: Box::new(return_type.clone()),
            };
            env.insert(name.clone(), func_type);
        }
    }

    // Second pass: type check each statement
    for stmt in program.iter_mut() {
        typecheck_stmt(stmt, &mut env)?;
    }

    Ok(())
}
```
Pass 1: Collect function signatures
We scan all function definitions and record their types before checking any bodies. Why? Because functions can call each other (mutual recursion):
```
def isEven(n: int) -> bool {
    if (n == 0) { return true }
    else { return isOdd(n - 1) }
}

def isOdd(n: int) -> bool {
    if (n == 0) { return false }
    else { return isEven(n - 1) }
}
```
When checking isEven, we need to know the type of isOdd. By collecting all signatures first, mutual recursion works.
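To see why the first pass matters, the sketch below compares the two-pass order with a naive single pass on the `isEven`/`isOdd` pair. Everything here is hypothetical and simplified: a `FnDecl` carries its signature plus the calls its body makes (callee name and already-inferred argument types), and `check_call` applies the arity and argument checks from the `Call` rule.

```rust
use std::collections::HashMap;

#[derive(Debug, Clone, PartialEq)]
enum Type {
    Int,
    Bool,
    Function { params: Vec<Type>, ret: Box<Type> },
}

/// Hypothetical function declaration: a signature plus the calls in its body.
struct FnDecl {
    name: String,
    params: Vec<Type>,
    ret: Type,
    calls: Vec<(String, Vec<Type>)>,
}

fn signature(decl: &FnDecl) -> Type {
    Type::Function { params: decl.params.clone(), ret: Box::new(decl.ret.clone()) }
}

/// Check one call against the environment and return the callee's result type.
fn check_call(env: &HashMap<String, Type>, callee: &str, args: &[Type]) -> Result<Type, String> {
    match env.get(callee) {
        Some(Type::Function { params, ret }) => {
            if params.len() != args.len() {
                return Err(format!(
                    "Function {} expects {} arguments, got {}",
                    callee, params.len(), args.len()
                ));
            }
            for (expected, actual) in params.iter().zip(args) {
                if expected != actual {
                    return Err(format!("Type mismatch: expected {:?}, got {:?}", expected, actual));
                }
            }
            Ok((**ret).clone())
        }
        Some(_) => Err(format!("{} is not a function", callee)),
        None => Err(format!("Undefined function: {}", callee)),
    }
}

/// Pass 1 records every signature; pass 2 checks every body's calls.
fn two_pass(decls: &[FnDecl]) -> Result<(), String> {
    let mut env = HashMap::new();
    for d in decls {
        env.insert(d.name.clone(), signature(d));
    }
    for d in decls {
        for (callee, args) in &d.calls {
            check_call(&env, callee, args)?;
        }
    }
    Ok(())
}

/// For contrast: one pass, where a function only sees itself and earlier functions.
fn one_pass(decls: &[FnDecl]) -> Result<(), String> {
    let mut env = HashMap::new();
    for d in decls {
        env.insert(d.name.clone(), signature(d)); // direct recursion still works
        for (callee, args) in &d.calls {
            check_call(&env, callee, args)?;
        }
    }
    Ok(())
}
```

With the single pass, `isEven` is checked before `isOdd` has been recorded, so the call fails with "Undefined function: isOdd". The two-pass order accepts the same program.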
Pass 2: Check each statement
Now we go through each statement, inferring types as we go.
Type Checking Expressions
Here is the complete typecheck_expr function:
```rust
fn typecheck_expr(expr: &mut TypedExpr, env: &TypeEnv) -> Result<(), String> {
    match &mut expr.expr {
        Expr::Int(_) => {
            expr.ty = Type::Int;
        }
        Expr::Bool(_) => {
            expr.ty = Type::Bool;
        }
        Expr::Var(name) => {
            if let Some(ty) = env.get(name) {
                expr.ty = ty.clone();
            } else {
                return Err(format!("Undefined variable: {}", name));
            }
        }
        Expr::Unary { op, expr: inner } => {
            typecheck_expr(inner, env)?;
            match op {
                UnaryOp::Neg => {
                    if inner.ty != Type::Int {
                        return Err(format!("Cannot negate non-integer type: {}", inner.ty));
                    }
                    expr.ty = Type::Int;
                }
                UnaryOp::Not => {
                    if inner.ty != Type::Bool {
                        return Err(format!("Cannot negate non-boolean type: {}", inner.ty));
                    }
                    expr.ty = Type::Bool;
                }
            }
        }
        Expr::Binary { op, left, right } => {
            typecheck_expr(left, env)?;
            typecheck_expr(right, env)?;
            match op {
                BinaryOp::Add | BinaryOp::Sub | BinaryOp::Mul | BinaryOp::Div | BinaryOp::Mod => {
                    if left.ty != Type::Int || right.ty != Type::Int {
                        return Err(format!(
                            "Arithmetic operation requires int operands, got {} and {}",
                            left.ty, right.ty
                        ));
                    }
                    expr.ty = Type::Int;
                }
                BinaryOp::Lt | BinaryOp::Gt | BinaryOp::Le | BinaryOp::Ge => {
                    if left.ty != Type::Int || right.ty != Type::Int {
                        return Err(format!(
                            "Comparison requires int operands, got {} and {}",
                            left.ty, right.ty
                        ));
                    }
                    expr.ty = Type::Bool;
                }
                BinaryOp::Eq | BinaryOp::Ne => {
                    let _ = left.ty.unify(&right.ty)?;
                    expr.ty = Type::Bool;
                }
            }
        }
        Expr::Call { name, args } => {
            // Look up function type
            let func_type = env
                .get(name)
                .ok_or_else(|| format!("Undefined function: {}", name))?
                .clone();

            if let Type::Function { params, ret } = func_type {
                // Check argument count
                if args.len() != params.len() {
                    return Err(format!(
                        "Function {} expects {} arguments, got {}",
                        name,
                        params.len(),
                        args.len()
                    ));
                }

                // Type check each argument against its declared parameter type.
                // The parameter type is the receiver so that unify's error
                // reads "expected <param type>, got <argument type>".
                for (arg, param_type) in args.iter_mut().zip(params.iter()) {
                    typecheck_expr(arg, env)?;
                    let _ = param_type.unify(&arg.ty)?;
                }

                expr.ty = *ret;
            } else {
                return Err(format!("{} is not a function", name));
            }
        }
        Expr::If {
            cond,
            then_branch,
            else_branch,
        } => {
            typecheck_expr(cond, env)?;
            if cond.ty != Type::Bool {
                return Err(format!("If condition must be bool, got {}", cond.ty));
            }

            // Type check branches
            let mut then_env = env.clone();
            let mut then_type = Type::Unit;
            for stmt in then_branch.iter_mut() {
                then_type = typecheck_stmt(stmt, &mut then_env)?;
            }

            let mut else_env = env.clone();
            let mut else_type = Type::Unit;
            for stmt in else_branch.iter_mut() {
                else_type = typecheck_stmt(stmt, &mut else_env)?;
            }

            // Branches must have same type
            expr.ty = then_type.unify(&else_type)?;
        }
        Expr::While { cond, body } => {
            typecheck_expr(cond, env)?;
            if cond.ty != Type::Bool {
                return Err(format!("While condition must be bool, got {}", cond.ty));
            }

            let mut body_env = env.clone();
            for stmt in body.iter_mut() {
                typecheck_stmt(stmt, &mut body_env)?;
            }
            expr.ty = Type::Unit;
        }
        Expr::Block(stmts) => {
            let mut block_env = env.clone();
            let mut last_type = Type::Unit;
            for stmt in stmts.iter_mut() {
                last_type = typecheck_stmt(stmt, &mut block_env)?;
            }
            expr.ty = last_type;
        }
    }
    Ok(())
}
```
The pattern is always the same:
- Recursively type check sub-expressions
- Apply the typing rule for this expression kind
- Set the type on this expression
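As one more instance of "apply the typing rule", here is the `if` rule pulled out on its own. The sketch assumes the condition and branch types have already been computed (step 1 of the pattern); `if_rule` is a hypothetical name, and a missing `else` is modeled as `Unit`, mirroring the `Type::Unit` default the checker uses for a branch with no statements.

```rust
#[derive(Debug, Clone, PartialEq)]
enum Type {
    Int,
    Bool,
    Unit,
    Unknown,
}

/// Cut-down unify: Unknown adopts the other type; otherwise types must be equal.
fn unify(a: &Type, b: &Type) -> Result<Type, String> {
    match (a, b) {
        (Type::Unknown, t) | (t, Type::Unknown) => Ok(t.clone()),
        (x, y) if x == y => Ok(x.clone()),
        _ => Err(format!("Type mismatch: expected {:?}, got {:?}", a, b)),
    }
}

/// Typing rule for `if (cond) { then } else { else }`, given already-inferred
/// types. `else_ty` is `None` when there is no else branch.
fn if_rule(cond: &Type, then_ty: &Type, else_ty: Option<&Type>) -> Result<Type, String> {
    // The condition must be Bool.
    if *cond != Type::Bool {
        return Err(format!("If condition must be bool, got {:?}", cond));
    }
    // A missing else branch behaves like an empty block of type Unit.
    let else_ty = else_ty.unwrap_or(&Type::Unit);
    // Both branches must agree; the if-expression takes the unified type.
    unify(then_ty, else_ty)
}
```

Under this rule an `if` without `else` is only well typed when its `then` branch is also `Unit`.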
When Inference Fails
Type inference is not magic. It fails when there is not enough information:
```
# This would fail: what type is x?
x = some_function_that_could_return_anything()
```
Or when types conflict:
```
x = 42
x = true    # Error: x is Int, cannot assign Bool
```
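The assignment pseudocode earlier only says "add (name -> inferred_type) to env", so how a second assignment is handled is a design choice. The sketch below assumes the rule stated in the comment above: the first assignment fixes a variable's type, and later assignments must produce the same type. `check_assign` is a hypothetical name, not the chapter's `typecheck_stmt`.

```rust
use std::collections::HashMap;

#[derive(Debug, Clone, PartialEq)]
enum Type {
    Int,
    Bool,
}

/// Hypothetical assignment rule: the first assignment fixes the variable's
/// type; any later assignment must produce a value of that same type.
fn check_assign(env: &mut HashMap<String, Type>, name: &str, value_ty: Type) -> Result<(), String> {
    if let Some(existing) = env.get(name) {
        // Re-assignment: the new value must match the recorded type.
        if *existing != value_ty {
            return Err(format!("{} has type {:?}, cannot assign {:?}", name, existing, value_ty));
        }
        return Ok(());
    }
    // First assignment: the value's type becomes the variable's type.
    env.insert(name.to_string(), value_ty);
    Ok(())
}
```

Without such a check, the second assignment would silently overwrite `x: Int` with `x: Bool`, and earlier uses of `x` would no longer agree with the environment.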
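We report errors with helpful messages. In the snippet below, typecheck stands for parsing a source string and then type checking the result (the typecheck function shown earlier takes an already parsed &mut Program):
```rust
let result = typecheck("1 + true");
// Error: "Arithmetic operation requires int operands, got int and bool"

let result = typecheck("add(1, true)");
// Error: "Type mismatch: expected int, got bool"
```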
Limitations of Local Inference
Our inference cannot handle some things that Hindley-Milner can:
Hindley-Milner could infer `identity : forall a. a -> a` for this unannotated function:
```
def identity(x) {
    return x
}
```
We require annotations:
```
def identity(x: int) -> int {
    return x
}
```
For a simple language like Secondlang, this is fine. The annotation burden is low (just function boundaries), and the implementation is much simpler.
Comparison with Other Systems
| Feature | Secondlang | TypeScript | Haskell |
|---|---|---|---|
| Variable inference | Yes | Yes | Yes |
| Function param inference | No | Partial | Yes |
| Polymorphism | No | Yes (generics) | Yes (parametric) |
| Bidirectional | No | Yes | Partial |
Summary
Type inference works by:
- Starting with known types: literals (`42` → Int, `true` → Bool) and annotated parameters
- Flowing types through expressions: operators, function calls, assignments
- Recording types in the environment: so variables can be looked up later
- Unifying types: checking compatibility and resolving Unknown
- Reporting errors: when types do not match
The beauty is that most of the time, you only need to annotate function parameters and return types. Everything else is inferred automatically.
Further Reading
- Type Inference on Wikipedia
- Hindley-Milner Type System
- Unification in Computer Science
- Types and Programming Languages by Benjamin Pierce - the definitive textbook
Try It Yourself
Run the inference example:
```
rustup run nightly cargo run -- examples/inference.sl
```
This demonstrates all the inference concepts in action.
Testing
```
cargo test typeck
```
At this point, you should be able to:
- Compile `x = 5` and have `x` inferred as `int`
- Get an error for `x = 5 + true` (type mismatch)
- Compile functions without annotating local variable types
In the next chapter, we look at optimizations we can do on the typed AST before generating code.