Kaleidoscope: Adding JIT and Optimizer Support

Chapter 4 Introduction

这一章主要是两个部分: (1) 给语言添加一个 Optimizer, (2) 添加 JIT 支持.

Trivial Constant Folding

第二章中我们并没有添加任何的优化操作, 然而, IRBuilder 自动帮我们完成了下列的优化:

1
2
3
4
5
6
7
ready> def test(x) 1+2+x;
Read function definition:
define double @test(double %x) {
entry:
%addtmp = fadd double 3.000000e+00, %x
ret double %addtmp
}

如果没有这个优化 (Constant folding), 我们得到的代码里就会把 1 2 相加, 再把结果加上 x. 这个类型的优化叫做 Constant folding, 很多语言在实现它的时候都是在 AST 表达中进行的. 但是使用 LLVM 实现语言, 使用 IRBuilder 来生成代码时, 这个优化会自动完成.

然而, IRBuilder 自身的优化也受到一定程度的限制:

1
2
3
4
5
6
7
ready> def test(x) (1+2+x)*(x+(1+2));
Read the function definition:
define double @test(double %x) {
entry:
%addtmp = fadd double 3.000000e+00, %x
%addtmp1 = fadd double %x, 3.000000e+00
%multmp = fmul double %addtmp, %addtmp1

在这个例子中, LHS 和 RHS 显然是相同的值, 但 IRBuilder 的 local analysis 不可能能检测到并且优化这些代码. 这需要两个 transformation: (1) reassociation of expressions (重新关联表达式, 使得 + 的表示变得唯一), (2) Common SubExpression Elimination (CSE, 公共子表达式消除, 删除重复的 + 指令).

所以我们要使用 “pass“ 来完成这些优化.

LLVM Optimization Passes

PassManger 改版了, 这篇教程是基于 llvm::legacy::FunctionPassManager 的, 这个类可以在 LegacyPassManager.h 找到. 但我还是尝试用新版的来实现

参考Luke的回答, 使用继承了 PassInfoMixin class 的新版 Pass; 使用 FunctionAnalysisManager 类来注册 Pass; 使用 PassBuilder 工具类辅助.

答案链接

LLVM 提供了很多优化 pass, 而且同时允许编译器开发者定义, 并在合适的时候调用自己的 pass.

举个具体的例子, LLVM 提供了对整个 Module 进行处理的 pass, 同时也包括对单个函数的 pass. 在 Kaleidoscope 中, 我们使用的是针对单个函数的优化 pass, 也就是说用户定义一个我们就优化一个.

首先添加我们需要的全局变量:

1
2
3
static std::unique_ptr<llvm::FunctionPassManager> TheFPM;
static std::unique_ptr<llvm::FunctionAnalysisManager> TheFAM;
static std::unique_ptr<llvm::PassBuilder> ThePB;

然后我们需要设置好 FunctionPassManager, 用它来添加我们所想运行的 pass, 由于每个 Module 都需要 new 一个 FunctionPassManager, 我们就添加一个初始化它们的函数. 我们在该函数的最后才初始化 TheFAM, ThePB:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
void InitializeModuleAndPassManager()
{
// open a new Module
TheModule = std::make_unique<llvm::Module>("my cool jit", TheContext);

// Create a new pass manager attached to it
TheFPM = std::make_unique<llvm::FunctionPassManager>(TheModule.get());

// Combine instructions to form fewer, simple
// instructions. This pass does not modify the CFG. This
// pass is where algebraic simplification happens.
// (https://llvm.org/docs/Passes.html#instcombine-combine-redundant-instructions)
TheFPM->addPass(llvm::InstCombinePass());

// This pass reassociates commutative expressions in an order
// that is designed to promote better constant propagation, GCSE, LICM, PRE, etc.
// (https://llvm.org/docs/Passes.html#reassociate-reassociate-expressions)
TheFPM->addPass(llvm::ReassociatePass());

// This pass performs global value numbering to eliminate fully and partially redundant instructions.
// It also performs redundant load elimination.
// (https://llvm.org/docs/Passes.html#gvn-global-value-numbering)
TheFPM->addPass(llvm::GVN());

// Performs dead code elimination and basic block merging.
// (https://llvm.org/docs/Passes.html#simplifycfg-simplify-the-cfg)
TheFPM->addPass(llvm::SimplifyCFGPass());

// Using PassBuilder and FunctionAnalysis Manager Register analysis passes
ThePB = std::make_unique<llvm::PassBuilder>();
TheFAM = std::make_unique<llvm::FunctionAnalysisManager>();
ThePB->registerFunctionAnalyses(*TheFAM.get());
}

这部分代码首先初始化全局的 TheModule, 和 TheFPM. 然后使用一系列的 addPass() 方法来给 TheFPM 添加 Pass. 添加完之后, 使用 ThePBTheFAM 来注册这些 Pass.

然后我们得使用它, 所以在 FunctionAST::codegen() 中生成代码之后加上:

1
2
3
4
5
6
7
8
9
// Finish off the function
Builder.CreateRet(RetVal);
// Validate the generated code, checking for consistency
llvm::verifyFunction(*TheFunction);

// Optimization of the function code
TheFPM->run(*TheFunction, *TheFAM.get());

return TheFunction;

记得在 main() 里面调用 InitializeModuleAndPassManager()

Adding a JIT compiler

将代码生成为 IR 后, 我们可以使用多种不同的工具来对它进行处理. 比如我们可以优化它 (上一节), 或者将它 dump 下来之后, 将 IR 编译为汇编文件 (.s), 或者使用 JIT 来编译它. LLVM IR 就相同于编译器个部分之间的通货 (common currency).

在这节我们将要给解释器添加 JIT 支持, 最基本的要求是: 一旦我们将函数体输入进去之后, 它能够立刻算出 top-level 表达式的值. 比如在输入 1+2; 之后立即输出一个 3. 同时定义好的函数也能直接在命令行被调用.

首先我们先准备 native target 的环境, 并声明和初始化 JIT. 注意这里的 class KaleidoscopeJIT 教程里面引用的是 LLVM 源码中的类, 但我不是源码安装, 所以还是把它复制过来再使用.

首先定义全局变量 static std::unique_ptr<KaleidoscopeJIT> TheJIT;, class Kaleidoscope 定义在 KaleidoscopeJIT.

然后在 main() 函数中, 添加这几段:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
int main()
{
#ifdef JIT
// Initialize native target
llvm::InitializeNativeTarget();
llvm::InitializeNativeTargetAsmPrinter();
llvm::InitializeNativeTargetAsmParser();
#endif
InitTokPrecedence();
// prime the first token
fprintf(stderr, "ready> ");
getNextToken();
TheJIT = std::make_unique<llvm::orc::KaleidoscopeJIT>();
InitializeModuleAndPassManager();
MainLoop();
#ifdef JIT
// Print out all of the generated code
TheModule->print(llvm::errs(), nullptr);
#endif
return 0;
}

同时要在 InitializeModuleAndPassManager() 函数中设置 data layout:

1
2
3
4
5
6
// open a new Module
TheModule = std::make_unique<llvm::Module>("my cool jit", TheContext);
TheModule->setDataLayout(TheJIT->getTargetMachine().createDataLayout());

// Create a new pass manager attached to it
TheFPM = std::make_unique<llvm::FunctionPassManager>(TheModule.get());

class Kaleidoscope 是一个简单的 JIT 构造类, 我们会在之后的章节中对它进行扩展. 目前的它提供的 API 非常简单: addModule() 给 JIT 添加了一个 Module, 使得函数能够执行; removeModule() 移除一个 Module, 释放相应的内存空间; findSymbol() 使我们能够通过 string name 来查找编译好代码的指针.

我们能使用这些 API 在 top-level 的 handle 函数中添加表达式值的计算:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
if (auto FnAST = ParseTopLevelExpr())
{
if (FnAST->codegen())
{
// JIT the module containing the anonymous expression, keeping a handle so we free it later
auto H = TheJIT->addModule(std::move(TheModule));
// Optimization of passes
InitializeModuleAndPassManager();

// Search the JIT for the __anon_expr symbol
auto ExprSymbol = TheJIT->findSymbol("__anon_expr");
assert(ExprSymbol && "Function not found");

// Get the symbol's address and cast it to the right type (takes no arguments, returns a double) so we can call it as a native function.
double (*FP)() = (double (*)())static_cast<intptr_t> (ExprSymbol.getAddress().get());
fprintf(stderr, "Evaluated to %f\n ", FP());

// Delete the anonymous expression module from the JIT
TheJIT->removeModule(H);
}
}

如果 Parsing 和 codegen 都成功执行, 下一步就是将包含 top-level 表达式的 Module 添加到 JIT, 我们使用 addModule() 这个方法, 触发这个 Module 中 所有函数的代码生成, 并返回一个 VModuleKey 的对象, 使得我们在之后可以移除 Module. 一旦这个 Module 被加入到 JIT, 它就不能被修改, 所以我们需要调用 InitializeModuleAndPassManager() 来打开一个新的 Module, 去持有接下来生成的代码.

一旦我们将 Module 添加到 JIT, 我们需要找到生成代码的指针, 所以我们使用 JIT 的findSymbol() 方法, 将 top-level 表达式的函数名 “__anon_expr” 当作参数传入, 就可以得到指令地址.

注意 “__anon_expr” 是要在 Parse top-level 表达式时传入的:

1
2
3
4
5
6
7
8
9
10
11
std::unique_ptr<FunctionAST> ParseTopLevelExpr()
{
if (auto E = ParseExpression())
{
// anonymous nullary function
auto Proto = std::make_unique<PrototypeAST>("__anon_expr", std::vector<std::string>());

return std::make_unique<FunctionAST>(std::move(Proto), std::move(E));
}
return nullptr;
}

然后我们使用 getAddress().get() 来获取这个函数在内存中的地址. 我们定义的匿名函数没有参数, 返回值为 double. 因为 LLVM JIT 编译器匹配了 native 平台的 ABI, 所以我们可以直接将函数地址转化成一个函数指针, 然后直接调用它. 这意味着, JIT 编译出的代码和静态链接到应用的 native 机器码没有区别.

最后, 由于我们不支持 top-level 的重复计算, 我们在代码生成的最后移除 Module, 释放调关联的内存. 但是我们在之前使用 InitializeModuleAndPassManager() 创建的 Module 仍然开启, 新的代码可以继续添加到其中.

再调用一次之前定义的函数会找不到符号, 因为一个 Module 是 JIT 分配的一个单元, 定义的函数位于前面 Module 中, 当我们把一个 Module 移除之后, 我们就相当于删除了那个 Module 中所有函数的定义, 所以再次调用就找不到符号.

最简单的方式就是将匿名表达式放在一个单独的 Module, 和其他函数定义分开.

实际上, 我们想更进一步, 将每个函数都放到一个单独的 Module, 这样就更加像一个 REPL, 函数能被定义多次, 引用它时, 都是返回最近的定义.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
static std::map<std::string, std::unique_ptr<PrototypeAST>> FunctionProtos;

llvm::Function *FunctionAST::codegen()
{
// Transfer ownership of the prototype to the FunctionProtos
// map, but keep a reference to it use below
auto &P = *Proto;
FunctionProtos[Proto->getName()] = std::move(Proto);
// Check for an existing function from a previous `extern` declaration
llvm::Function *TheFunction = getFunction(P.getName());

if (!TheFunction)
{
return nullptr;
}
....
}

llvm::Value *CallExprAST::codegen()
{
// Lookup the name in the global module table
llvm::Function *CalleeF = getFunction(Callee);
...
}

llvm::Function *getFunction(std::string Name) {
// First, check if the function has already been added to the
// current mode
if (auto *F = TheModule->getFunction(Name)) {
return F;
}

// If not exist, check wether we can codegen the declaration // from some existing prototype
auto FI = FunctionProtos.find(Name);
if (FI != FunctionProtos.end()) {
return FI->second->codegen();
}

// if no existing prototype exists, return null
return nullptr;
}

首先添加一个全局的 map<string, unique_ptr<PrototypeAST>> FunctionProtos, 用它来保存每个函数最近的一次函数原型. 并且我们添加了一个 getFunction() , 代替 TheModule->getFunction(). 自定义的 getFunction(), 首先它搜寻 Module, 看有没有存在的函数声明, 如果没有就生成函数原型的代码. 所以我们可以把之前 Function *FunctionAST::codegen() 之中的那个生成函数原型的代码片段删除. 同时在 Function *FunctionAST::codegen() 中, 我们首先更新 FunctionProtos, 再调用 getFunction(). 通过使用这个全局的 FunctionProtos, 我们能获取之前声明过的所有函数.

我们还需要更新 HandleDefinition()HandleExtern():

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
void HandleDefinition()
{
if (auto FnAST = ParseDefinition())
{
if (auto *FnIR = FnAST->codegen())
{
fprintf(stderr, "Read the function definition: ");
FnIR->print(llvm::errs());
fprintf(stderr, "\n");
TheJIT->addModule(std::move(TheModule));
InitializeModuleAndPassManager();
}
}
else
{
// Skip token for error recovery
getNextToken();
}
}

void HandleExtern()
{
if (auto ProtoAST = ParseExtern())
{
if (auto *FnIR = ProtoAST->codegen())
{
fprintf(stderr, "Read extern: ");
FnIR->print(llvm::errs());
fprintf(stderr, "\n");
FunctionProtos[ProtoAST->getName()] = std::move(ProtoAST);
}
}
else
{
// Skip token for error recovery
getNextToken();
}
}

HandleDefinition() 中, 我们需要将刚刚定义好的函数所在的 Module 添加到 JIT, 然后打开一个新的 Module, 这样就使得每个函数都在不同的 Module 中. 在 HandleExtern() 中, 只需要将这个外部声明添加到 FunctionProtos 中即可.

并且, 对于外部声明 e.g., extern sin(x), JIT 有一套直接的符号解析机制: JIT 按照时间逆序, 搜寻所有添加的 Module, 来查找符号的定义. 如果没找到, 它就 fall back 去调用 dlsym("sin"), 去 JIT 的地址空间里面去找这个符号的定义, 然后去调用它 (libm).

后面我们会进一步讨论 JIT 的符号解析机制, 并调整它来实现一些 feature, 比如安全性, 动态代码生成, 甚至是 lazy evaluation.

调整符号解析规则一个最直接的好处就是, 我们能使用任意的 C++ 代码来扩展语言:

1
2
3
4
extern "C" DLLEXPORT double putchard(double X) {
fputc((char)X, stderr);
return 0;
}

链接时添加 -rdynamic, 使得 dlopen() 打开的 shared object 能解析到自己程序中的符号.