Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 34 additions & 23 deletions src/wasm-validator.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,26 @@
#define wasm_wasm_validator_h

#include <set>
#include <sstream>

#include "wasm.h"
#include "wasm-printing.h"

namespace wasm {

// Print anything that can be streamed to an ostream
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

perhaps these two could go in wasm-printing.h?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought about that but they are kind of specific to the validator where we want to easily make everything print with the extra type info. Otherwise they sort of just duplicate the stream-related functions there.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh, i see. sounds good.

template <typename T>
inline std::ostream& printModuleComponent(T curr, std::ostream& stream) {
stream << curr << std::endl;
return stream;
}
// Specialization for Expressions to print type info too
template <>
inline std::ostream& printModuleComponent(Expression* curr, std::ostream& stream) {
WasmPrinter::printExpression(curr, stream, false, true) << std::endl;
return stream;
}

struct WasmValidator : public PostWalker<WasmValidator> {
bool valid = true;

Expand Down Expand Up @@ -123,6 +137,8 @@ struct WasmValidator : public PostWalker<WasmValidator> {
void visitSetLocal(SetLocal *curr);
void visitLoad(Load *curr);
void visitStore(Store *curr);
void visitAtomicRMW(AtomicRMW *curr);
void visitAtomicCmpxchg(AtomicCmpxchg *curr);
void visitBinary(Binary *curr);
void visitUnary(Unary *curr);
void visitSelect(Select* curr);
Expand All @@ -144,21 +160,22 @@ struct WasmValidator : public PostWalker<WasmValidator> {

// helpers
private:
std::ostream& fail();
template <typename T, typename S>
std::ostream& fail(T curr, S text);
std::ostream& printFailureHeader();

template<typename T>
bool shouldBeTrue(bool result, T curr, const char* text) {
if (!result) {
fail() << "unexpected false: " << text << ", on \n" << curr << std::endl;
valid = false;
fail(curr, "unexpected false: " + std::string(text));
return false;
}
return result;
}
template<typename T>
bool shouldBeFalse(bool result, T curr, const char* text) {
if (result) {
fail() << "unexpected true: " << text << ", on \n" << curr << std::endl;
valid = false;
fail(curr, "unexpected true: " + std::string(text));
return false;
}
return result;
Expand All @@ -167,18 +184,9 @@ struct WasmValidator : public PostWalker<WasmValidator> {
template<typename T, typename S>
bool shouldBeEqual(S left, S right, T curr, const char* text) {
if (left != right) {
fail() << "" << left << " != " << right << ": " << text << ", on \n";
WasmPrinter::printExpression(curr, std::cerr, false, true) << std::endl;
valid = false;
return false;
}
return true;
}
template<typename T, typename S, typename U>
bool shouldBeEqual(S left, S right, T curr, U other, const char* text) {
if (left != right) {
fail() << "" << left << " != " << right << ": " << text << ", on \n" << curr << " / " << other << std::endl;
valid = false;
std::ostringstream ss;
ss << left << " != " << right << ": " << text;
fail(curr, ss.str());
return false;
}
return true;
Expand All @@ -187,9 +195,9 @@ struct WasmValidator : public PostWalker<WasmValidator> {
template<typename T, typename S>
bool shouldBeEqualOrFirstIsUnreachable(S left, S right, T curr, const char* text) {
if (left != unreachable && left != right) {
fail() << "" << left << " != " << right << ": " << text << ", on \n";
WasmPrinter::printExpression(curr, std::cerr, false, true) << std::endl;
valid = false;
std::ostringstream ss;
ss << left << " != " << right << ": " << text;
fail(curr, ss.str());
return false;
}
return true;
Expand All @@ -198,14 +206,17 @@ struct WasmValidator : public PostWalker<WasmValidator> {
template<typename T, typename S>
bool shouldBeUnequal(S left, S right, T curr, const char* text) {
if (left == right) {
fail() << "" << left << " == " << right << ": " << text << ", on \n" << curr << std::endl;
valid = false;
std::ostringstream ss;
ss << left << " == " << right << ": " << text;
fail(curr, ss.str());
return false;
}
return true;
}

void validateAlignment(size_t align, WasmType type, Index bytes);
void validateAlignment(size_t align, WasmType type, Index bytes, bool isAtomic,
Expression* curr);
void validateMemBytes(uint8_t bytes, WasmType ty, Expression* curr);
void validateBinaryenIR(Module& wasm);
};

Expand Down
52 changes: 42 additions & 10 deletions src/wasm/wasm-validator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -214,15 +214,36 @@ void WasmValidator::visitSetLocal(SetLocal *curr) {
}
}
void WasmValidator::visitLoad(Load *curr) {
validateAlignment(curr->align, curr->type, curr->bytes);
validateMemBytes(curr->bytes, curr->type, curr);
validateAlignment(curr->align, curr->type, curr->bytes, curr->isAtomic, curr);
shouldBeEqualOrFirstIsUnreachable(curr->ptr->type, i32, curr, "load pointer type must be i32");
}
void WasmValidator::visitStore(Store *curr) {
validateAlignment(curr->align, curr->type, curr->bytes);
validateMemBytes(curr->bytes, curr->valueType, curr);
validateAlignment(curr->align, curr->type, curr->bytes, curr->isAtomic, curr);
shouldBeEqualOrFirstIsUnreachable(curr->ptr->type, i32, curr, "store pointer type must be i32");
shouldBeUnequal(curr->value->type, none, curr, "store value type must not be none");
shouldBeEqualOrFirstIsUnreachable(curr->value->type, curr->valueType, curr, "store value type must match");
}
void WasmValidator::visitAtomicRMW(AtomicRMW* curr) {
validateMemBytes(curr->bytes, curr->type, curr);
}
void WasmValidator::visitAtomicCmpxchg(AtomicCmpxchg* curr) {
validateMemBytes(curr->bytes, curr->type, curr);
}
void WasmValidator::validateMemBytes(uint8_t bytes, WasmType ty, Expression* curr) {
switch (bytes) {
case 1:
case 2:
case 4:
break;
case 8: {
shouldBeEqual(getWasmTypeSize(ty), 8U, curr, "8-byte mem operations are only allowed with 8-byte wasm types");
break;
}
default: fail("Memory operations must be 1,2,4, or 8 bytes", curr);
}
}
void WasmValidator::visitBinary(Binary *curr) {
if (curr->left->type != unreachable && curr->right->type != unreachable) {
shouldBeEqual(curr->left->type, curr->right->type, curr, "binary child types must be equal");
Expand Down Expand Up @@ -561,28 +582,32 @@ void WasmValidator::visitModule(Module *curr) {
}
}

void WasmValidator::validateAlignment(size_t align, WasmType type, Index bytes) {
void WasmValidator::validateAlignment(size_t align, WasmType type, Index bytes,
bool isAtomic, Expression* curr) {
if (isAtomic) {
shouldBeEqual(align, (size_t)bytes, curr, "atomic accesses must have natural alignment");
return;
}
switch (align) {
case 1:
case 2:
case 4:
case 8: break;
default:{
fail() << "bad alignment: " << align << std::endl;
valid = false;
fail("bad alignment: " + std::to_string(align), curr);
break;
}
}
shouldBeTrue(align <= bytes, align, "alignment must not exceed natural");
shouldBeTrue(align <= bytes, curr, "alignment must not exceed natural");
switch (type) {
case i32:
case f32: {
shouldBeTrue(align <= 4, align, "alignment must not exceed natural");
shouldBeTrue(align <= 4, curr, "alignment must not exceed natural");
break;
}
case i64:
case f64: {
shouldBeTrue(align <= 8, align, "alignment must not exceed natural");
shouldBeTrue(align <= 8, curr, "alignment must not exceed natural");
break;
}
default: {}
Expand All @@ -609,7 +634,7 @@ void WasmValidator::validateBinaryenIR(Module& wasm) {
// The block has an added type, not derived from the ast itself, so it is
// ok for it to be either i32 or unreachable.
if (!(isConcreteWasmType(oldType) && newType == unreachable)) {
parent.fail() << "stale type found in " << (getFunction() ? getFunction()->name : Name("(global scope)")) << " on " << curr << "\n(marked as " << printWasmType(oldType) << ", should be " << printWasmType(newType) << ")\n";
parent.printFailureHeader() << "stale type found in " << (getFunction() ? getFunction()->name : Name("(global scope)")) << " on " << curr << "\n(marked as " << printWasmType(oldType) << ", should be " << printWasmType(newType) << ")\n";
parent.valid = false;
}
curr->type = oldType;
Expand All @@ -620,7 +645,14 @@ void WasmValidator::validateBinaryenIR(Module& wasm) {
binaryenIRValidator.walkModule(&wasm);
}

std::ostream& WasmValidator::fail() {
template <typename T, typename S>
std::ostream& WasmValidator::fail(T curr, S text) {
valid = false;
auto& ret = printFailureHeader() << text << ", on \n";
return printModuleComponent(curr, ret);
}

std::ostream& WasmValidator::printFailureHeader() {
Colors::red(std::cerr);
if (getFunction()) {
std::cerr << "[wasm-validator error in function ";
Expand Down