Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 24 additions & 7 deletions src/passes/Print.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -350,22 +350,25 @@ struct PrintSExpression : public Visitor<PrintSExpression> {
printFullLine(curr->value);
decIndent();
}
void visitAtomicRMW(AtomicRMW* curr) {
o << '(';
prepareColor(o) << printWasmType(curr->type) << ".atomic.rmw";
if (curr->bytes != getWasmTypeSize(curr->type)) {
if (curr->bytes == 1) {
static void printRMWSize(std::ostream& o, WasmType type, uint8_t bytes) {
prepareColor(o) << printWasmType(type) << ".atomic.rmw";
if (bytes != getWasmTypeSize(type)) {
if (bytes == 1) {
o << '8';
} else if (curr->bytes == 2) {
} else if (bytes == 2) {
o << "16";
} else if (curr->bytes == 4) {
} else if (bytes == 4) {
o << "32";
} else {
WASM_UNREACHABLE();
}
o << "_u";
}
o << '.';
}
void visitAtomicRMW(AtomicRMW* curr) {
o << '(';
printRMWSize(o, curr->type, curr->bytes);
switch (curr->op) {
case Add: o << "add"; break;
case Sub: o << "sub"; break;
Expand All @@ -383,6 +386,20 @@ struct PrintSExpression : public Visitor<PrintSExpression> {
printFullLine(curr->value);
decIndent();
}
void visitAtomicCmpxchg(AtomicCmpxchg* curr) {
o << '(';
printRMWSize(o, curr->type, curr->bytes);
o << "cmpxchg";
restoreNormalColor(o);
if (curr->offset) {
o << " offset=" << curr->offset;
}
incIndent();
printFullLine(curr->ptr);
printFullLine(curr->expected);
printFullLine(curr->replacement);
decIndent();
}
void visitConst(Const *curr) {
o << curr->value;
}
Expand Down
12 changes: 12 additions & 0 deletions src/wasm-binary.h
Original file line number Diff line number Diff line change
Expand Up @@ -571,6 +571,16 @@ enum AtomicOpcodes {
I64AtomicRMWXchg16U = 0x46,
I64AtomicRMWXchg32U = 0x47,
AtomicRMWOps_End = 0x47,

AtomicCmpxchgOps_Begin = 0x48,
I32AtomicCmpxchg = 0x48,
I64AtomicCmpxchg = 0x49,
I32AtomicCmpxchg8U = 0x4a,
I32AtomicCmpxchg16U = 0x4b,
I64AtomicCmpxchg8U = 0x4c,
I64AtomicCmpxchg16U = 0x4d,
I64AtomicCmpxchg32U = 0x4e,
AtomicCmpxchgOps_End = 0x4e
};


Expand Down Expand Up @@ -723,6 +733,7 @@ class WasmBinaryWriter : public Visitor<WasmBinaryWriter, void> {
void visitLoad(Load *curr);
void visitStore(Store *curr);
void visitAtomicRMW(AtomicRMW *curr);
void visitAtomicCmpxchg(AtomicCmpxchg *curr);
void visitConst(Const *curr);
void visitUnary(Unary *curr);
void visitBinary(Binary *curr);
Expand Down Expand Up @@ -881,6 +892,7 @@ class WasmBinaryBuilder {
bool maybeVisitLoad(Expression*& out, uint8_t code, bool isAtomic);
bool maybeVisitStore(Expression*& out, uint8_t code, bool isAtomic);
bool maybeVisitAtomicRMW(Expression*& out, uint8_t code);
bool maybeVisitAtomicCmpxchg(Expression*& out, uint8_t code);
bool maybeVisitConst(Expression*& out, uint8_t code);
bool maybeVisitUnary(Expression*& out, uint8_t code);
bool maybeVisitBinary(Expression*& out, uint8_t code);
Expand Down
4 changes: 3 additions & 1 deletion src/wasm-s-parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,9 @@ class SExpressionWasmBuilder {
Expression* makeConst(Element& s, WasmType type);
Expression* makeLoad(Element& s, WasmType type, bool isAtomic);
Expression* makeStore(Element& s, WasmType type, bool isAtomic);
Expression* makeAtomicRMW(Element& s, WasmType type);
Expression* makeAtomicRMWOrCmpxchg(Element& s, WasmType type);
Expression* makeAtomicRMW(Element& s, WasmType type, uint8_t bytes, const char* extra);
Expression* makeAtomicCmpxchg(Element& s, WasmType type, uint8_t bytes, const char* extra);
Expression* makeIf(Element& s);
Expression* makeMaybeBlock(Element& s, size_t i, WasmType type);
Expression* makeLoop(Element& s);
Expand Down
11 changes: 11 additions & 0 deletions src/wasm-traversal.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ struct Visitor {
ReturnType visitLoad(Load* curr) {}
ReturnType visitStore(Store* curr) {}
ReturnType visitAtomicRMW(AtomicRMW* curr) {return ReturnType();} //Stub impl so not every pass has to implement this yet.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we make all these functions have the return, I believe we decided that was best?

Also, please add spaces after { and before }.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have a local branch with that change which I'll submit as a separate PR because it depends on this one (I already had this branch in flight when we decided that).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cool, lgtm

ReturnType visitAtomicCmpxchg(AtomicCmpxchg* curr) {return ReturnType();} //Stub impl so not every pass has to implement this yet.
ReturnType visitConst(Const* curr) {}
ReturnType visitUnary(Unary* curr) {}
ReturnType visitBinary(Binary* curr) {}
Expand Down Expand Up @@ -92,6 +93,7 @@ struct Visitor {
case Expression::Id::LoadId: DELEGATE(Load);
case Expression::Id::StoreId: DELEGATE(Store);
case Expression::Id::AtomicRMWId: DELEGATE(AtomicRMW);
case Expression::Id::AtomicCmpxchgId: DELEGATE(AtomicCmpxchg);
case Expression::Id::ConstId: DELEGATE(Const);
case Expression::Id::UnaryId: DELEGATE(Unary);
case Expression::Id::BinaryId: DELEGATE(Binary);
Expand Down Expand Up @@ -133,6 +135,7 @@ struct UnifiedExpressionVisitor : public Visitor<SubType> {
ReturnType visitLoad(Load* curr) { return static_cast<SubType*>(this)->visitExpression(curr); }
ReturnType visitStore(Store* curr) { return static_cast<SubType*>(this)->visitExpression(curr); }
ReturnType visitAtomicRMW(AtomicRMW* curr) { return static_cast<SubType*>(this)->visitExpression(curr); }
ReturnType visitAtomicCmpxchg(AtomicCmpxchg* curr) { return static_cast<SubType*>(this)->visitExpression(curr); }
ReturnType visitConst(Const* curr) { return static_cast<SubType*>(this)->visitExpression(curr); }
ReturnType visitUnary(Unary* curr) { return static_cast<SubType*>(this)->visitExpression(curr); }
ReturnType visitBinary(Binary* curr) { return static_cast<SubType*>(this)->visitExpression(curr); }
Expand Down Expand Up @@ -310,6 +313,7 @@ struct Walker : public VisitorType {
static void doVisitLoad(SubType* self, Expression** currp) { self->visitLoad((*currp)->cast<Load>()); }
static void doVisitStore(SubType* self, Expression** currp) { self->visitStore((*currp)->cast<Store>()); }
static void doVisitAtomicRMW(SubType* self, Expression** currp) { self->visitAtomicRMW((*currp)->cast<AtomicRMW>()); }
static void doVisitAtomicCmpxchg(SubType* self, Expression** currp){ self->visitAtomicCmpxchg((*currp)->cast<AtomicCmpxchg>()); }
static void doVisitConst(SubType* self, Expression** currp) { self->visitConst((*currp)->cast<Const>()); }
static void doVisitUnary(SubType* self, Expression** currp) { self->visitUnary((*currp)->cast<Unary>()); }
static void doVisitBinary(SubType* self, Expression** currp) { self->visitBinary((*currp)->cast<Binary>()); }
Expand Down Expand Up @@ -438,6 +442,13 @@ struct PostWalker : public Walker<SubType, VisitorType> {
self->pushTask(SubType::scan, &curr->cast<AtomicRMW>()->ptr);
break;
}
case Expression::Id::AtomicCmpxchgId: {
self->pushTask(SubType::doVisitAtomicCmpxchg, currp);
self->pushTask(SubType::scan, &curr->cast<AtomicCmpxchg>()->replacement);
self->pushTask(SubType::scan, &curr->cast<AtomicCmpxchg>()->expected);
self->pushTask(SubType::scan, &curr->cast<AtomicCmpxchg>()->ptr);
break;
}
case Expression::Id::ConstId: {
self->pushTask(SubType::doVisitConst, currp);
break;
Expand Down
11 changes: 10 additions & 1 deletion src/wasm.h
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ enum HostOp {
};

enum AtomicRMWOp {
Add, Sub, And, Or, Xor, Xchg,
Add, Sub, And, Or, Xor, Xchg
};

//
Expand Down Expand Up @@ -445,6 +445,15 @@ class AtomicRMW : public SpecificExpression<Expression::AtomicRMWId> {
class AtomicCmpxchg : public SpecificExpression<Expression::AtomicCmpxchgId> {
public:
AtomicCmpxchg() = default;
AtomicCmpxchg(MixedArena& allocator) : AtomicCmpxchg() {}

uint8_t bytes;
Address offset;
Expression* ptr;
Expression* expected;
Expression* replacement;

void finalize();
};

class Const : public SpecificExpression<Expression::ConstId> {
Expand Down
64 changes: 64 additions & 0 deletions src/wasm/wasm-binary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -877,6 +877,37 @@ void WasmBinaryWriter::visitAtomicRMW(AtomicRMW *curr) {
emitMemoryAccess(curr->bytes, curr->bytes, curr->offset);
}

void WasmBinaryWriter::visitAtomicCmpxchg(AtomicCmpxchg *curr) {
if (debug) std::cerr << "zz node: AtomicCmpxchg" << std::endl;
recurse(curr->ptr);
recurse(curr->expected);
recurse(curr->replacement);

o << int8_t(BinaryConsts::AtomicPrefix);
switch (curr->type) {
case i32:
switch (curr->bytes) {
case 1: o << int8_t(BinaryConsts::I32AtomicCmpxchg8U); break;
case 2: o << int8_t(BinaryConsts::I32AtomicCmpxchg16U); break;
case 4: o << int8_t(BinaryConsts::I32AtomicCmpxchg); break;
default: WASM_UNREACHABLE();
}
break;
case i64:
switch (curr->bytes) {
case 1: o << int8_t(BinaryConsts::I64AtomicCmpxchg8U); break;
case 2: o << int8_t(BinaryConsts::I64AtomicCmpxchg16U); break;
case 4: o << int8_t(BinaryConsts::I64AtomicCmpxchg32U); break;
case 8: o << int8_t(BinaryConsts::I64AtomicCmpxchg); break;
default: WASM_UNREACHABLE();
}
break;
default: WASM_UNREACHABLE();
}
emitMemoryAccess(curr->bytes, curr->bytes, curr->offset);
}


void WasmBinaryWriter::visitConst(Const *curr) {
if (debug) std::cerr << "zz node: Const" << curr << " : " << curr->type << std::endl;
switch (curr->type) {
Expand Down Expand Up @@ -1980,6 +2011,7 @@ BinaryConsts::ASTNodes WasmBinaryBuilder::readExpression(Expression*& curr) {
if (maybeVisitLoad(curr, code, /*isAtomic=*/true)) break;
if (maybeVisitStore(curr, code, /*isAtomic=*/true)) break;
if (maybeVisitAtomicRMW(curr, code)) break;
if (maybeVisitAtomicCmpxchg(curr, code)) break;
throw ParseException("invalid code after atomic prefix: " + std::to_string(code));
}
default: {
Expand Down Expand Up @@ -2372,6 +2404,38 @@ bool WasmBinaryBuilder::maybeVisitAtomicRMW(Expression*& out, uint8_t code) {
return true;
}

bool WasmBinaryBuilder::maybeVisitAtomicCmpxchg(Expression*& out, uint8_t code) {
if (code < BinaryConsts::AtomicCmpxchgOps_Begin || code > BinaryConsts::AtomicCmpxchgOps_End) return false;
auto* curr = allocator.alloc<AtomicCmpxchg>();

// Set curr to the given type and size.
#define SET(optype, size) \
curr->type = optype; \
curr->bytes = size

switch (code) {
case BinaryConsts::I32AtomicCmpxchg: SET(i32, 4); break;
case BinaryConsts::I64AtomicCmpxchg: SET(i64, 8); break;
case BinaryConsts::I32AtomicCmpxchg8U: SET(i32, 1); break;
case BinaryConsts::I32AtomicCmpxchg16U: SET(i32, 2); break;
case BinaryConsts::I64AtomicCmpxchg8U: SET(i64, 1); break;
case BinaryConsts::I64AtomicCmpxchg16U: SET(i64, 2); break;
case BinaryConsts::I64AtomicCmpxchg32U: SET(i64, 4); break;
default: WASM_UNREACHABLE();
}

if (debug) std::cerr << "zz node: AtomicCmpxchg" << std::endl;
Address readAlign;
readMemoryAccess(readAlign, curr->bytes, curr->offset);
if (readAlign != curr->bytes) throw ParseException("Align of AtomicCpxchg must match size");
curr->replacement = popNonVoidExpression();
curr->expected = popNonVoidExpression();
curr->ptr = popNonVoidExpression();
curr->finalize();
out = curr;
return true;
}

bool WasmBinaryBuilder::maybeVisitConst(Expression*& out, uint8_t code) {
Const* curr;
if (debug) std::cerr << "zz node: Const, code " << code << std::endl;
Expand Down
30 changes: 25 additions & 5 deletions src/wasm/wasm-s-parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -666,7 +666,7 @@ Expression* SExpressionWasmBuilder::makeExpression(Element& s) {
if (op[1] == 't' && !strncmp(op, "atomic.", strlen("atomic."))) {
if (op[7] == 'l') return makeLoad(s, type, /*isAtomic=*/true);
if (op[7] == 's') return makeStore(s, type, /*isAtomic=*/true);
if (op[7] == 'r') return makeAtomicRMW(s, type);
if (op[7] == 'r') return makeAtomicRMWOrCmpxchg(s, type);
}
abort_on(op);
}
Expand Down Expand Up @@ -1197,14 +1197,20 @@ Expression* SExpressionWasmBuilder::makeStore(Element& s, WasmType type, bool is
return ret;
}

Expression* SExpressionWasmBuilder::makeAtomicRMW(Element& s, WasmType type) {
Expression* SExpressionWasmBuilder::makeAtomicRMWOrCmpxchg(Element& s, WasmType type) {
const char* extra = strchr(s[0]->c_str(), '.') + 11; // afer "type.atomic.rmw"
auto ret = allocator.alloc<AtomicRMW>();
ret->type = type;
ret->bytes = parseMemBytes(&extra, getWasmTypeSize(type));
auto bytes = parseMemBytes(&extra, getWasmTypeSize(type));
extra = strchr(extra, '.'); // after the optional '_u' and before the opcode
if (!extra) throw ParseException("malformed atomic rmw instruction");
extra++; // after the '.'
if (!strncmp(extra, "cmpxchg", 7)) return makeAtomicCmpxchg(s, type, bytes, extra);
return makeAtomicRMW(s, type, bytes, extra);
}

Expression* SExpressionWasmBuilder::makeAtomicRMW(Element& s, WasmType type, uint8_t bytes, const char* extra) {
auto ret = allocator.alloc<AtomicRMW>();
ret->type = type;
ret->bytes = bytes;
if (!strncmp(extra, "add", 3)) ret->op = Add;
else if (!strncmp(extra, "and", 3)) ret->op = And;
else if (!strncmp(extra, "or", 2)) ret->op = Or;
Expand All @@ -1221,6 +1227,20 @@ Expression* SExpressionWasmBuilder::makeAtomicRMW(Element& s, WasmType type) {
return ret;
}

Expression* SExpressionWasmBuilder::makeAtomicCmpxchg(Element& s, WasmType type, uint8_t bytes, const char* extra) {
auto ret = allocator.alloc<AtomicCmpxchg>();
ret->type = type;
ret->bytes = bytes;
Address align;
size_t i = parseMemAttributes(s, &ret->offset, &align, ret->bytes);
if (align != ret->bytes) throw ParseException("Align of Atomic Cmpxchg must match size");
ret->ptr = parseExpression(s[i]);
ret->expected = parseExpression(s[i+1]);
ret->replacement = parseExpression(s[i+2]);
ret->finalize();
return ret;
}

Expression* SExpressionWasmBuilder::makeIf(Element& s) {
auto ret = allocator.alloc<If>();
Index i = 1;
Expand Down
6 changes: 6 additions & 0 deletions src/wasm/wasm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,12 @@ void AtomicRMW::finalize() {
}
}

void AtomicCmpxchg::finalize() {
if (ptr->type == unreachable || expected->type == unreachable || replacement->type == unreachable) {
type = unreachable;
}
}

Const* Const::set(Literal value_) {
value = value_;
type = value.type;
Expand Down
32 changes: 32 additions & 0 deletions test/atomics.wast
Original file line number Diff line number Diff line change
Expand Up @@ -102,4 +102,36 @@
)
)
)
(func $atomic-cmpxchg (type $0)
(local $0 i32)
(local $1 i32)
(drop
(i32.atomic.rmw.cmpxchg offset=4
(get_local $0)
(get_local $0)
(get_local $0)
)
)
(drop
(i32.atomic.rmw8_u.cmpxchg
(get_local $0)
(get_local $0)
(get_local $0)
)
)
(drop
(i64.atomic.rmw.cmpxchg offset=4
(get_local $0)
(get_local $0)
(get_local $0)
)
)
(drop
(i64.atomic.rmw32_u.cmpxchg align=4
(get_local $0)
(get_local $0)
(get_local $0)
)
)
)
)
32 changes: 32 additions & 0 deletions test/atomics.wast.from-wast
Original file line number Diff line number Diff line change
Expand Up @@ -102,4 +102,36 @@
)
)
)
(func $atomic-cmpxchg (type $0)
(local $0 i32)
(local $1 i32)
(drop
(i32.atomic.rmw.cmpxchg offset=4
(get_local $0)
(get_local $0)
(get_local $0)
)
)
(drop
(i32.atomic.rmw8_u.cmpxchg
(get_local $0)
(get_local $0)
(get_local $0)
)
)
(drop
(i64.atomic.rmw.cmpxchg offset=4
(get_local $0)
(get_local $0)
(get_local $0)
)
)
(drop
(i64.atomic.rmw32_u.cmpxchg
(get_local $0)
(get_local $0)
(get_local $0)
)
)
)
)
Loading