diff --git a/src/amberscript/parser.cc b/src/amberscript/parser.cc index 8d124836d..953fc12b5 100644 --- a/src/amberscript/parser.cc +++ b/src/amberscript/parser.cc @@ -1024,6 +1024,9 @@ Result Parser::ParsePipelineSet(Pipeline* pipeline) { if (!type) return Result("invalid data type '" + token->AsString() + "' provided"); + if (type->IsVec() || type->IsMatrix() || type->IsArray() || type->IsStruct()) + return Result("data type must be a scalar type"); + token = tokenizer_->NextToken(); if (!token->IsInteger() && !token->IsDouble()) return Result("expected data value"); diff --git a/src/amberscript/parser_pipeline_set_test.cc b/src/amberscript/parser_pipeline_set_test.cc index 7d78ca739..46f0dd45f 100644 --- a/src/amberscript/parser_pipeline_set_test.cc +++ b/src/amberscript/parser_pipeline_set_test.cc @@ -241,5 +241,22 @@ END EXPECT_EQ("7: SET can only be used with OPENCL-C shaders", r.Error()); } +TEST_F(AmberScriptParserTest, OpenCLSetNonScalarDataType) { + std::string in = R"( +SHADER compute my_shader OPENCL-C +#shader +END +PIPELINE compute my_pipeline + ATTACH my_shader + SET KERNEL ARG_NAME arg_a AS vec4 0 0 0 0 +END +)"; + + Parser parser; + auto r = parser.Parse(in); + ASSERT_FALSE(r.IsSuccess()); + EXPECT_EQ("7: data type must be a scalar type", r.Error()); +} + } // namespace amberscript } // namespace amber diff --git a/src/pipeline.cc b/src/pipeline.cc index eca6326b7..478e6e179 100644 --- a/src/pipeline.cc +++ b/src/pipeline.cc @@ -755,7 +755,34 @@ Result Pipeline::GenerateOpenCLPodBuffers() { return Result(message); } - Result r = buffer->SetDataWithOffset({arg_info.value}, offset); + // Convert the argument value into bytes. Currently, only scalar arguments + // are supported. + const auto arg_byte_size = arg_info.fmt->SizeInBytes(); + std::vector data_bytes; + for (uint32_t i = 0; i < arg_byte_size; ++i) { + Value v; + if (arg_info.value.IsFloat()) { + if (arg_byte_size == sizeof(double)) { + union { + uint64_t u; + double d; + } u; + u.d = arg_info.value.AsDouble(); + v.SetIntValue((u.u >> (i * 8)) & 0xff); + } else { + union { + uint32_t u; + float f; + } u; + u.f = arg_info.value.AsFloat(); + v.SetIntValue((u.u >> (i * 8)) & 0xff); + } + } else { + v.SetIntValue((arg_info.value.AsUint64() >> (i * 8)) & 0xff); + } + data_bytes.push_back(v); + } + Result r = buffer->SetDataWithOffset(data_bytes, offset); if (!r.IsSuccess()) return r; }