Skip to content

Commit

Permalink
fix comments
Browse files Browse the repository at this point in the history
  • Loading branch information
zhczhong committed Jul 30, 2024
1 parent ed5180d commit 974b8ca
Showing 1 changed file with 30 additions and 11 deletions.
41 changes: 30 additions & 11 deletions lib/gc/Analysis/MatmulConfigAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,9 @@ llvm::raw_ostream &operator<<(llvm::raw_ostream &ss,

template <typename T>
static llvm::raw_ostream &operator<<(llvm::raw_ostream &ss,
std::vector<T> arry) {
std::vector<T> array) {
ss << "[";
for (auto [idx, a] : llvm::enumerate(arry)) {
if (idx != 0) {
ss << ", ";
}
ss << a;
}
llvm::interleaveComma(array, ss);
ss << "]";
return ss;
}
Expand Down Expand Up @@ -174,7 +169,7 @@ std::vector<MatmulConfig>
filterConfigByCostModel(ArrayRef<MatmulConfig> configs,
linalg::LinalgOp &linalgOp, ArrayRef<uint32_t> shape,
SystemDesc &sysDesc, const CostModelFn &costModel,
float eliminationRatio = 0.5, float threshold = -1) {
float preserveRatio = 0.5, float threshold = -1) {
std::vector<MatmulConfig> result;
std::vector<float> costs;
std::vector<size_t> idx;
Expand All @@ -185,8 +180,7 @@ filterConfigByCostModel(ArrayRef<MatmulConfig> configs,
std::stable_sort(idx.begin(), idx.end(), [&costs](size_t i1, size_t i2) {
return costs[i1] < costs[i2];
});
double thresholdCost =
costs[idx[(size_t)(eliminationRatio * configs.size())]];
double thresholdCost = costs[idx[(size_t)(preserveRatio * configs.size())]];
thresholdCost =
threshold < thresholdCost && threshold > 0 ? threshold : thresholdCost;
for (size_t i = 0; i < configs.size(); i++) {
Expand All @@ -210,6 +204,11 @@ std::vector<MatmulConfig>
prepareConfigCandidates(Operation *root, SystemDesc &sysDesc,
ArrayRef<uint32_t> shape,
ArrayRef<uint32_t> givenInnermostBlock) {
if (shape.size() < 3) {
LLVM_DEBUG(llvm::dbgs()
<< "The shape is invalid, no candidate is generated\n");
return {};
}
std::vector<MatmulConfig> configs;
uint32_t threads = sysDesc.getNumThreads();
std::vector<uint32_t> MThreadsCandidates =
Expand Down Expand Up @@ -290,6 +289,21 @@ prepareConfigCandidates(Operation *root, SystemDesc &sysDesc,
return configs;
}

bool validateConfig(const MatmulConfig &cfg) {
if (cfg.MThreads <= 0 || cfg.NThreads <= 0 || cfg.KThreads <= 0 ||
cfg.MBlock <= 0 || cfg.NBlock <= 0 || cfg.KBlock <= 0 ||
cfg.innerMostMBlock <= 0 || cfg.innerMostNBlock <= 0 ||
cfg.innerMostKBlock <= 0) {
return false;
}
if (cfg.MBlock % cfg.innerMostMBlock != 0 ||
cfg.NBlock % cfg.innerMostNBlock != 0 ||
cfg.KBlock % cfg.innerMostKBlock != 0) {
return false;
}
return true;
}

// read the config from the attributes for tuning
bool readConfigFromAttrs(MatmulConfig &config, ArrayRef<NamedAttribute> attrs) {
size_t cfgItemCnt = 0;
Expand Down Expand Up @@ -323,7 +337,12 @@ bool readConfigFromAttrs(MatmulConfig &config, ArrayRef<NamedAttribute> attrs) {
cfgItemCnt++;
}
}
return cfgItemCnt == 9;
if (validateConfig(config)) {
return cfgItemCnt == 9;
} else {
LLVM_DEBUG(llvm::dbgs() << "The predefined config is invalid\n");
return false;
}
}

// Analyze the workload and system description to generate the default config
Expand Down

0 comments on commit 974b8ca

Please sign in to comment.