Skip to content

Commit 3282bfe

Browse files
committed
refactoring; lit tests added
1 parent b429a91 commit 3282bfe

4 files changed

+1672
-79
lines changed

src/passes/Asyncify.cpp

+91-79
Original file line numberDiff line numberDiff line change
@@ -1274,32 +1274,37 @@ struct AsyncifyAddCatchCounters : public Pass {
12741274
CountersBuilder builder(*module_);
12751275
BranchUtils::BranchTargets branchTargets(func->body);
12761276

1277-
// with this walker we will find level of "nesting" for each expression
1278-
// ... - +0
1277+
// with this walker we will assign count of enclosing catch block to
1278+
// each expression
1279+
// ... - 0
12791280
// catch
1280-
// ... - +1
1281+
// ... - 1
12811282
// catch
1282-
// ... - +2
1283-
std::unordered_map<Expression*, int> expressionNestedLevel;
1283+
// ... - 2
1284+
std::unordered_map<Expression*, int> expressionCatchCount;
12841285
struct NestedLevelWalker
12851286
: public PostWalker<NestedLevelWalker,
12861287
UnifiedExpressionVisitor<NestedLevelWalker>> {
1287-
std::unordered_map<Expression*, int>* expressionNestedLevel;
1288-
int nestedLevel = 0;
1288+
std::unordered_map<Expression*, int>* expressionCatchCount;
1289+
int catchCount = 0;
12891290

12901291
static void doStartCatch(NestedLevelWalker* self, Expression** currp) {
1291-
self->nestedLevel++;
1292+
self->catchCount++;
12921293
}
12931294

12941295
static void doEndCatch(NestedLevelWalker* self, Expression** currp) {
1295-
self->nestedLevel--;
1296+
self->catchCount--;
12961297
}
12971298

12981299
static void scan(NestedLevelWalker* self, Expression** currp) {
12991300
auto curr = *currp;
13001301
if (curr->_id == Expression::Id::TryId) {
1302+
self->expressionCatchCount->insert(
1303+
std::make_pair<>(curr, self->catchCount));
13011304
auto& catchBodies = curr->cast<Try>()->catchBodies;
13021305
for (Index i = 0; i < catchBodies.size(); i++) {
1306+
self->expressionCatchCount->insert(
1307+
std::make_pair<>(catchBodies[i], self->catchCount));
13031308
self->pushTask(doEndCatch, currp);
13041309
self->pushTask(NestedLevelWalker::scan, &catchBodies[i]);
13051310
self->pushTask(doStartCatch, currp);
@@ -1314,146 +1319,153 @@ struct AsyncifyAddCatchCounters : public Pass {
13141319
}
13151320

13161321
void visitExpression(Expression* curr) {
1317-
expressionNestedLevel->insert(std::make_pair<>(curr, nestedLevel));
1322+
expressionCatchCount->insert(std::make_pair<>(curr, catchCount));
13181323
}
13191324
};
13201325
NestedLevelWalker nestedLevelWalker;
1321-
nestedLevelWalker.expressionNestedLevel = &expressionNestedLevel;
1326+
nestedLevelWalker.expressionCatchCount = &expressionCatchCount;
13221327
nestedLevelWalker.walk(func->body);
13231328

1324-
// with this walker we will handle those counters:
1329+
// with this walker we will handle those changes of counter:
13251330
// - entering into catch (= pop) +1
1326-
// - return -N (nested catch count up to root)
1327-
// - break -N (nested catch count up to label)
1331+
// - return -1
1332+
// - break -1
13281333
// - exiting from catch -1
13291334
struct AddCountersWalker : public PostWalker<AddCountersWalker> {
13301335
Function* func;
13311336
CountersBuilder* builder;
13321337
BranchUtils::BranchTargets* branchTargets;
1333-
std::unordered_map<Expression*, int>* expressionNestedLevel;
1334-
int labelNum = 0;
1338+
std::unordered_map<Expression*, int>* expressionCatchCount;
1339+
int finallyNum = 0;
1340+
int popNum = 0;
1341+
1342+
int getCatchCount(Expression* expression) {
1343+
auto it = expressionCatchCount->find(expression);
1344+
assert(it != expressionCatchCount->end());
1345+
return it->second;
1346+
}
13351347

13361348
// Each catch block except catch_all should have pop instruction
1337-
// We increment counter each time when pop happens (= entering catch
1338-
// block)
1349+
// We increment counter each time when we enter top-level catch block
13391350
void visitPop(Pop* pop) {
1340-
replaceCurrent(builder->makeSequence(pop, builder->makeInc()));
1351+
if (getCatchCount(pop) == 1) {
1352+
auto name =
1353+
func->name.toString() + "-pop-" + std::to_string(++popNum);
1354+
replaceCurrent(
1355+
builder->makeBlock(name, {pop, builder->makeInc()}, Type::none));
1356+
}
13411357
}
13421358
void visitLocalSet(LocalSet* set) {
13431359
auto block = set->value->dynCast<Block>(); // from visitPop above
1344-
if (block) {
1360+
if (block && block->name.hasSubstring("-pop-")) {
13451361
auto pop = block->list[0]->dynCast<Pop>();
1346-
if (pop) {
1347-
set->value = pop;
1348-
replaceCurrent(builder->makeSequence(set, builder->makeInc()));
1349-
}
1362+
assert(pop && getCatchCount(pop) == 1);
1363+
set->value = pop;
1364+
replaceCurrent(builder->makeBlock(
1365+
block->name, {set, builder->makeInc()}, Type::none));
13501366
}
13511367
}
13521368

1353-
// When return happens we decrement counter on amount of nested catch
1354-
// blocks up to root catch
1369+
// When return happens we decrement counter on 1, because we account
1370+
// only top-level catch blocks
13551371
// catch
13561372
// +1
13571373
// catch
1358-
// +1
1359-
// ...
1360-
// -2
1374+
// ;; not counted
1375+
// -1
13611376
// return
13621377
// ...
13631378
void visitReturn(Return* ret) {
1364-
auto it = expressionNestedLevel->find(ret);
1365-
assert(it != expressionNestedLevel->end());
1366-
auto nestedLevel = it->second;
1367-
if (nestedLevel > 0) {
1368-
replaceCurrent(
1369-
builder->makeSequence(builder->makeDec(nestedLevel), ret));
1379+
if (getCatchCount(ret) > 0) {
1380+
replaceCurrent(builder->makeSequence(builder->makeDec(), ret));
13701381
}
13711382
}
13721383

1373-
// When break happens we decrement counter on amount of nested catch
1374-
// blocks up to label
1384+
// When break happens we decrement counter only if it goes out
1385+
// from top-level catch block
13751386
void visitBreak(Break* br) {
1376-
auto it = expressionNestedLevel->find(br);
1377-
assert(it != expressionNestedLevel->end());
1378-
auto nestedLevel = it->second;
1379-
13801387
Expression* target = branchTargets->getTarget(br->name);
13811388
assert(target != nullptr);
1382-
1383-
it = expressionNestedLevel->find(target);
1384-
assert(it != expressionNestedLevel->end());
1385-
1386-
auto amount = nestedLevel - it->second;
1387-
assert(amount >= 0);
1388-
1389-
if (amount > 0) {
1389+
if (getCatchCount(br) > 0 && getCatchCount(target) == 0) {
13901390
if (br->condition == nullptr) {
1391-
replaceCurrent(builder->makeSequence(builder->makeDec(amount), br));
1392-
} else {
1393-
auto decIf = builder->makeIf(
1394-
br->condition,
1395-
builder->makeSequence(builder->makeDec(amount), br),
1396-
br->value);
1391+
replaceCurrent(builder->makeSequence(builder->makeDec(), br));
1392+
} else if (br->value == nullptr) {
1393+
auto decIf =
1394+
builder->makeIf(br->condition,
1395+
builder->makeSequence(builder->makeDec(), br),
1396+
nullptr);
13971397
br->condition = nullptr;
13981398
replaceCurrent(decIf);
1399+
} else {
1400+
Index newLocal = builder->addVar(func, br->value->type);
1401+
auto setLocal = builder->makeLocalSet(newLocal, br->value);
1402+
auto getLocal = builder->makeLocalGet(newLocal, br->value->type);
1403+
auto condition = br->condition;
1404+
br->condition = nullptr;
1405+
br->value = getLocal;
1406+
auto decIf =
1407+
builder->makeIf(condition,
1408+
builder->makeSequence(builder->makeDec(), br),
1409+
getLocal);
1410+
replaceCurrent(builder->makeSequence(setLocal, decIf));
13991411
}
14001412
}
14011413
}
14021414

1403-
// Replacing each catch block with try/finally and increase counter for
1404-
// catch_all blocks (not handled by visitPop); dec counter at the end
1405-
// of catch block
1406-
// try {fn}-finally-{label}
1415+
// Replacing each top-level catch block with try/catch_all(finally) and
1416+
// increase counter for catch_all blocks (not handled by visitPop); dec
1417+
// counter at the end of catch block try ({fn}-finally-{label})
14071418
// +1
14081419
// {catch body}
14091420
// -1
1410-
// catch
1421+
// catch_all
14111422
// -1
14121423
// rethrow {fn}-finally-{label}
14131424
void visitTry(Try* curr) {
1414-
for (size_t i = 0; i < curr->catchBodies.size(); ++i) {
1415-
curr->catchBodies[i] =
1416-
addCatchCounters(curr->catchBodies[i], i == curr->catchTags.size());
1425+
if (getCatchCount(curr) == 0) {
1426+
for (size_t i = 0; i < curr->catchBodies.size(); ++i) {
1427+
curr->catchBodies[i] = addCatchCounters(
1428+
curr->catchBodies[i], i == curr->catchTags.size());
1429+
}
14171430
}
14181431
}
14191432
Expression* addCatchCounters(Expression* expression, bool catchAll) {
1420-
// catch_all case is not covered by PopWalker
1433+
auto block = expression->dynCast<Block>();
1434+
if (block == nullptr) {
1435+
block = builder->makeBlock(expression);
1436+
}
1437+
1438+
// catch_all case is not covered by visitPop
14211439
if (catchAll) {
1422-
auto block = expression->dynCast<Block>();
1423-
assert(block != nullptr);
14241440
block->list.insertAt(0, builder->makeInc());
14251441
}
14261442

14271443
// dec counters at the end of catch
1428-
if (expression->type == Type::none) {
1429-
if (auto block = expression->dynCast<Block>()) {
1430-
auto last = block->list[block->list.size() - 1];
1431-
if (!last->dynCast<Return>()) {
1432-
block->list.push_back(builder->makeDec());
1433-
block->finalize();
1434-
}
1435-
} else {
1436-
WASM_UNREACHABLE("Unexpected expression type");
1444+
if (block->type == Type::none) {
1445+
auto last = block->list[block->list.size() - 1];
1446+
if (!last->dynCast<Return>()) {
1447+
block->list.push_back(builder->makeDec());
1448+
block->finalize();
14371449
}
14381450
}
14391451

14401452
auto name =
1441-
func->name.toString() + "-finally-" + std::to_string(++labelNum);
1453+
func->name.toString() + "-finally-" + std::to_string(++finallyNum);
14421454
return builder->makeTry(
14431455
name,
1444-
expression,
1456+
block,
14451457
{},
14461458
{builder->makeSequence(builder->makeDec(),
14471459
builder->makeRethrow(name))},
1448-
expression->type);
1460+
block->type);
14491461
}
14501462
};
14511463

14521464
AddCountersWalker addCountersWalker;
14531465
addCountersWalker.func = func;
14541466
addCountersWalker.builder = &builder;
14551467
addCountersWalker.branchTargets = &branchTargets;
1456-
addCountersWalker.expressionNestedLevel = &expressionNestedLevel;
1468+
addCountersWalker.expressionCatchCount = &expressionCatchCount;
14571469
addCountersWalker.walk(func->body);
14581470

14591471
EHUtils::handleBlockNestedPops(func, *module_);

0 commit comments

Comments
 (0)