void CodeGenCPU::VisitStmt_(const ForNode* op) {
  ICHECK(is_zero(op->min));
  if (op->kind == ForKind::kSerial || op->kind == ForKind::kUnrolled) {
    CodeGenLLVM::VisitStmt_(op);
  } else if (op->kind == ForKind::kParallel) {
    if (parallel_env_.penv == nullptr) {
      // Not yet inside a parallel closure: outline the loop into a lambda and
      // launch it. num_task = 0 lets the runtime choose the number of tasks.
      CreateParallelLaunch(For(op->loop_var, op->min, op->extent, op->kind, op->body,
                               op->thread_binding, op->annotations),
                           0, std::string("loop_parallel_") + op->loop_var->name_hint.c_str());
    } else {
      // Already inside a parallel closure: split the iteration space across tasks.
      ICHECK(parallel_env_.task_id.defined());
      ICHECK(parallel_env_.num_task.defined());
      ICHECK(parallel_env_.penv != nullptr);
      DataType t = op->extent.dtype();
      PrimExpr num_task = cast(t, parallel_env_.num_task);
      PrimExpr task_id = cast(t, parallel_env_.task_id);
      ICHECK(!parallel_env_.in_parallel_loop)
          << "Nested parallel loop is not supported by threadpool, try fuse them instead";
      parallel_env_.in_parallel_loop = true;
      if (parallel_env_.stride_pattern) {
        // Strided: task k visits k, k + num_task, k + 2 * num_task, ...
        CreateSerialFor(MakeValue(task_id), MakeValue(op->extent), MakeValue(num_task),
                        op->loop_var, op->body);
      } else {
        // Blocked: each task runs one contiguous chunk of ceil(extent / num_task)
        // iterations, with both bounds clamped to extent.
        PrimExpr step = (op->extent + num_task - make_const(t, 1)) / num_task;
        PrimExpr begin = min(task_id * step, op->extent);
        PrimExpr end = min((task_id + make_const(t, 1)) * step, op->extent);
        CreateSerialFor(MakeValue(begin), MakeValue(end),
                        llvm::ConstantInt::getSigned(GetLLVMType(end), 1), op->loop_var, op->body);
      }
      parallel_env_.in_parallel_loop = false;
      ++parallel_env_.parallel_loop_count;
    }
  } else {
    LOG(FATAL) << "cannot handle for type " << op->kind;
  }
}
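
// Illustrative arithmetic for the two partitioning schemes in VisitStmt_ above.
// The extents and task counts are example values only, not taken from any workload.
//
// Blocked (default), extent = 10, num_task = 4:
//   step = (10 + 4 - 1) / 4 = 3
//   task 0: [0, 3)   task 1: [3, 6)   task 2: [6, 9)   task 3: [9, min(12, 10)) = [9, 10)
//
// Blocked, extent = 5, num_task = 4:
//   step = (5 + 4 - 1) / 4 = 2
//   task 0: [0, 2)   task 1: [2, 4)   task 2: [4, min(6, 5)) = [4, 5)
//   task 3: [min(6, 5), min(8, 5)) = [5, 5), an empty range, so that task skips the loop.
//
// Strided (stride_pattern), extent = 10, num_task = 4:
//   task 1 visits i = 1, 5, 9 (the next value, 13, is not < 10).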
void CodeGenCPU::CreateParallelLaunch(const Stmt& body, int num_task, std::string name) {
  // The closure function that every parallel task runs.
  llvm::Function* f =
      llvm::Function::Create(ftype_tvm_parallel_lambda_, llvm::Function::PrivateLinkage,
                             "__tvm_parallel_lambda", module_.get());
  SetTargetAttributes(f);

  // Pack the body's free variables into closure data and emit the launch call.
  Array<Var> vfields = tir::UndefinedVars(body, {});
  uint64_t nbytes;
  TypedPointer cdata = PackClosureData(vfields, &nbytes, "closure_" + name);
#if TVM_LLVM_VERSION >= 90
  auto launch_callee = llvm::FunctionCallee(ftype_tvm_parallel_launch_, RuntimeTVMParallelLaunch());
#else
  auto launch_callee = RuntimeTVMParallelLaunch();
#endif
  llvm::BasicBlock* par_launch_end = CheckCallSuccess(builder_->CreateCall(
      launch_callee, {f, builder_->CreatePointerCast(cdata.addr, t_void_p_), ConstInt32(num_task)}));

  // Emit the closure body. Its arguments are (task_id, penv, cdata).
  auto* lambda_entry =
      llvm::BasicBlock::Create(*llvm_target_->GetContext(), "parallel_closure_entry", f);
  builder_->SetInsertPoint(lambda_entry);
  auto it = f->arg_begin();
  llvm::Value* task_id = &(*it++);
  task_id->setName("task_id");
  llvm::Value* penv = &(*it++);
  cdata.addr = builder_->CreatePointerCast(&(*it++), cdata.addr->getType());
  // Build a fresh variable map from the unpacked closure fields.
  std::unordered_map<const VarNode*, llvm::Value*> new_vmap;
  UnpackClosureData(cdata, vfields, &new_vmap);
  // Set up the parallel environment; num_task is loaded from field 1 of *penv
  // (TVMParallelGroupEnv::num_task).
  ParallelEnv par_env;
  par_env.task_id = Var("task_id", DataType::Int(32));
  par_env.num_task = Var("num_task", DataType::Int(32));
  new_vmap[par_env.task_id.get()] = task_id;
  new_vmap[par_env.num_task.get()] = builder_->CreateLoad(
      t_int32_,
      builder_->CreateInBoundsGEP(t_tvm_parallel_group_env_, penv, {ConstInt32(0), ConstInt32(1)}),
      "num_task");
  par_env.penv = penv;
  auto new_analyzer = std::make_unique<arith::Analyzer>();
  // Swap in the closure's codegen state, emit the body, then restore the caller's state.
  std::swap(function_, f);
  std::swap(parallel_env_, par_env);
  std::swap(analyzer_, new_analyzer);
  std::swap(var_map_, new_vmap);
  this->VisitStmt(body);
  builder_->CreateRet(ConstInt32(0));
  std::swap(var_map_, new_vmap);
  std::swap(analyzer_, new_analyzer);
  std::swap(parallel_env_, par_env);
  std::swap(function_, f);
  ICHECK_NE(par_env.parallel_loop_count, 0) << "Cannot find parallel loop within parallel launch";
  // Resume emitting the caller after a successful launch.
  builder_->SetInsertPoint(par_launch_end);
}
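
// Rough C-level sketch of what CreateParallelLaunch emits. This is an illustration,
// not generated output. It is written against the backend API declared in
// tvm/runtime/c_backend_api.h. `closure_t`, `n`, and `A` are hypothetical names for
// the packed free variables of the loop body.
//
//   static int __tvm_parallel_lambda(int task_id, TVMParallelGroupEnv* penv, void* cdata) {
//     closure_t* c = (closure_t*)cdata;   // UnpackClosureData
//     int32_t num_task = penv->num_task;  // the InBoundsGEP {0, 1} load
//     /* loop body, partitioned by VisitStmt_ using task_id and num_task */
//     return 0;
//   }
//
//   /* in the enclosing function */
//   closure_t closure = {n, A};           // PackClosureData
//   int ret = TVMBackendParallelLaunch(__tvm_parallel_lambda, &closure, 0);
//   if (ret != 0) return ret;             // CheckCallSuccess, roughly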