调用链
tvm搜索算子在需要多线程运行的算子,是在codegen阶段时插入TVMBackendParallelLaunch的调用。
TVMBackendParallelLaunch 是tvm的线程池并行化入口,具体如下
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19/*!
* \brief The callback function to execute a parallel lambda
* \param task_id the task id of the function. //这里实际就是线程池线程编码,对应第几个线程
* \param penv The parallel environment backs the execution. // num_task, sync
* \param cdata The supporting closure data.
*/
typedef int (*FTVMParallelLambda)(int task_id, TVMParallelGroupEnv* penv, void* cdata);
/*!
* \brief Backend function for running parallel jobs.
*
* \param flambda The parallel function to be launched.
* \param cdata The closure data. // 可以认为时循环的变量 codegen时生成
* \param num_task Number of tasks to launch, can be 0, means launch
* with all available threads. // codegen 时写入的是0,运行时根据配置写入
*
* \return 0 when no error is thrown, -1 when failure happens
*/
int TVMBackendParallelLaunch(FTVMParallelLambda flambda, void* cdata, int num_task);
flambda的调用在单线程和多线程下略有区别。
单线程运行时
1
2
3
4
5
6
7
8if (num_workers == 1) {
std::atomic<int32_t> sync_counter{0};
TVMParallelGroupEnv env;
env.num_task = 1;
env.sync_handle = &sync_counter;
(*flambda)(0, &env, cdata);
return 0;
}1
2
3
4
5
6
7
8
9
10
11
12
13
14
15// launcher->Init(flambda, cdata, num_task, need_sync != 0);
this->cdata = cdata;
this->flambda = flambda;
this->env.num_task = num_task;
while (queue->Pop(&task, spin_count)) {
ICHECK(task.launcher != nullptr);
TVMParallelGroupEnv* penv = &(task.launcher->env);
void* cdata = task.launcher->cdata;
if ((*task.launcher->flambda)(task.task_id, penv, cdata) == 0) {
task.launcher->SignalJobFinish();
} else {
task.launcher->SignalJobError(task.task_id);
}
}TVMParallelGroupEnv* penv 包含了实际的运行时线程,运行时可以根据这个确定每个线程的工作区间和步长。
cdata则是线程运行时需要变量信息,闭包变量。
总结
对要并行的函数,实际上是按照lambda表达式的方式生成的。FTVMParallelLambda 的输入参数前两个是运行时确定的,第三个是捕获的外部变量。
codegen 过程
下面验证一下上述的猜测。
codegen过程中,实际上是在遍历tir Stmt的AST,因为生成的循环都是基于For的,调用过程也比较简单了。
1
2
3
4
5void CodeGenCPU::VisitStmt_(const ForNode* op) // ->
CreateParallelLaunch(For(op->loop_var, op->min, op->extent, op->kind, op->body,
op->thread_binding, op->annotations),
0, std::string("loop_parallel_") + op->loop_var->name_hint.c_str()); // ->
CodeGenCPU::VisitStmt_(const ForNode* op);parallel_env_.penv == nullptr 创建多线程调用函数,进入CreateParallelLaunch函数。
然后 再生成 For的遍历逻辑。this->VisitStmt(body); 这里的body其实还是For ,这时候就进入
1
2} else {
// already in parallel env.
1 |
|