Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Android.mk
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ LOCAL_SRC_FILES:= \
src/descriptor_set_and_binding_parser.cc \
src/engine.cc \
src/executor.cc \
src/float16_helper.cc \
src/format.cc \
src/parser.cc \
src/pipeline.cc \
Expand Down
7 changes: 5 additions & 2 deletions docs/amber_script.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,10 @@ DEVICE_FEATURE VariablePointerFeatures.variablePointersStorageBuffer
```

Currently each of the items in `VkPhysicalDeviceFeatures` are recognized along
with `VariablePointerFeatures.variablePointers` and
`VariablePointerFeatures.variablePointersStorageBuffer`.
with:
* `VariablePointerFeatures.variablePointers`
* `VariablePointerFeatures.variablePointersStorageBuffer`
* `Float16Int8Features.shaderFloat16`

Extensions can be enabled with the `DEVICE_EXTENSION` and `INSTANCE_EXTENSION`
commands.
Expand Down Expand Up @@ -114,6 +116,7 @@ either image buffers or, what the target API would refer to as a buffer.
* `uint16`
* `uint32`
* `uint64`
* `float16`
* `float`
* `double`
* vec[2,3,4]{type}
Expand Down
22 changes: 17 additions & 5 deletions samples/config_helper_vulkan.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ const size_t kNumberOfRequiredValidationLayers =
const char kVariablePointers[] = "VariablePointerFeatures.variablePointers";
const char kVariablePointersStorageBuffer[] =
"VariablePointerFeatures.variablePointersStorageBuffer";
const char kFloat16Int8_Float16[] = "Float16Int8Features.shaderFloat16";

const char kExtensionForValidationLayer[] = "VK_EXT_debug_report";

Expand Down Expand Up @@ -598,8 +599,8 @@ std::string deviceTypeToName(VkPhysicalDeviceType type) {
ConfigHelperVulkan::ConfigHelperVulkan()
: available_features_(VkPhysicalDeviceFeatures()),
available_features2_(VkPhysicalDeviceFeatures2KHR()),
variable_pointers_feature_(VkPhysicalDeviceVariablePointerFeaturesKHR()) {
}
variable_pointers_feature_(VkPhysicalDeviceVariablePointerFeaturesKHR()),
float16_int8_feature_(VkPhysicalDeviceFloat16Int8FeaturesKHR()) {}

ConfigHelperVulkan::~ConfigHelperVulkan() {
if (vulkan_device_)
Expand Down Expand Up @@ -666,9 +667,10 @@ amber::Result ConfigHelperVulkan::CreateVulkanInstance(

// Determine if VkPhysicalDeviceProperties2KHR should be used
for (auto& ext : required_extensions) {
if (ext == "VK_KHR_get_physical_device_properties2") {
if (ext == "VK_KHR_get_physical_device_properties2")
supports_get_physical_device_properties2_ = true;
}
if (ext == "VK_KHR_shader_float16_int8")
supports_shader_float16_int8_ = true;
}

std::vector<const char*> required_extensions_in_char;
Expand Down Expand Up @@ -882,9 +884,17 @@ amber::Result ConfigHelperVulkan::CreateDeviceWithFeatures1(
amber::Result ConfigHelperVulkan::CreateDeviceWithFeatures2(
const std::vector<std::string>& required_features,
VkDeviceCreateInfo* info) {
float16_int8_feature_.sType =
VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FLOAT16_INT8_FEATURES_KHR;
float16_int8_feature_.pNext = nullptr;

variable_pointers_feature_.sType =
VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VARIABLE_POINTER_FEATURES_KHR;
variable_pointers_feature_.pNext = nullptr;

if (supports_shader_float16_int8_)
variable_pointers_feature_.pNext = &float16_int8_feature_;
else
variable_pointers_feature_.pNext = nullptr;

available_features2_.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2_KHR;
available_features2_.pNext = &variable_pointers_feature_;
Expand All @@ -901,6 +911,8 @@ amber::Result ConfigHelperVulkan::CreateDeviceWithFeatures2(
variable_pointers_feature_.variablePointers = VK_TRUE;
else if (feature == kVariablePointersStorageBuffer)
variable_pointers_feature_.variablePointersStorageBuffer = VK_TRUE;
else if (feature == kFloat16Int8_Float16)
float16_int8_feature_.shaderFloat16 = VK_TRUE;
}

VkPhysicalDeviceFeatures required_vulkan_features =
Expand Down
2 changes: 2 additions & 0 deletions samples/config_helper_vulkan.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,11 @@ class ConfigHelperVulkan : public ConfigHelperImpl {
VkDevice vulkan_device_ = VK_NULL_HANDLE;

bool supports_get_physical_device_properties2_ = false;
bool supports_shader_float16_int8_ = false;
VkPhysicalDeviceFeatures available_features_;
VkPhysicalDeviceFeatures2KHR available_features2_;
VkPhysicalDeviceVariablePointerFeaturesKHR variable_pointers_feature_;
VkPhysicalDeviceFloat16Int8FeaturesKHR float16_int8_feature_;
};

} // namespace sample
Expand Down
2 changes: 2 additions & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ set(AMBER_SOURCES
descriptor_set_and_binding_parser.cc
engine.cc
executor.cc
float16_helper.cc
format.cc
parser.cc
pipeline.cc
Expand Down Expand Up @@ -138,6 +139,7 @@ if (${AMBER_ENABLE_TESTS})
command_data_test.cc
descriptor_set_and_binding_parser_test.cc
executor_test.cc
float16_helper_test.cc
format_test.cc
pipeline_test.cc
result_test.cc
Expand Down
2 changes: 2 additions & 0 deletions src/amberscript/parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ std::unique_ptr<type::Type> ToType(const std::string& str) {
return parser.Parse("R32_UINT");
if (str == "uint64")
return parser.Parse("R64_UINT");
if (str == "float16")
return parser.Parse("R16_SFLOAT");
if (str == "float")
return parser.Parse("R32_SFLOAT");
if (str == "double")
Expand Down
6 changes: 4 additions & 2 deletions src/amberscript/parser_device_feature_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,20 @@ using AmberScriptParserTest = testing::Test;
TEST_F(AmberScriptParserTest, DeviceFeature) {
std::string in = R"(
DEVICE_FEATURE vertexPipelineStoresAndAtomics
DEVICE_FEATURE VariablePointerFeatures.variablePointersStorageBuffer)";
DEVICE_FEATURE VariablePointerFeatures.variablePointersStorageBuffer
DEVICE_FEATURE Float16Int8Features.shaderFloat16)";

Parser parser;
Result r = parser.Parse(in);
ASSERT_TRUE(r.IsSuccess()) << r.Error();

auto script = parser.GetScript();
const auto& features = script->GetRequiredFeatures();
ASSERT_EQ(2U, features.size());
ASSERT_EQ(3U, features.size());
EXPECT_EQ("vertexPipelineStoresAndAtomics", features[0]);
EXPECT_EQ("VariablePointerFeatures.variablePointersStorageBuffer",
features[1]);
EXPECT_EQ("Float16Int8Features.shaderFloat16", features[2]);
}

TEST_F(AmberScriptParserTest, DeviceFeatureMissingFeature) {
Expand Down
38 changes: 6 additions & 32 deletions src/buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,37 +19,11 @@
#include <cmath>
#include <cstring>

#include "src/float16_helper.h"

namespace amber {
namespace {

// Return sign value of 32 bits float.
uint16_t FloatSign(const uint32_t hex_float) {
return static_cast<uint16_t>(hex_float >> 31U);
}

// Return exponent value of 32 bits float.
uint16_t FloatExponent(const uint32_t hex_float) {
uint32_t exponent = ((hex_float >> 23U) & ((1U << 8U) - 1U)) - 112U;
const uint32_t half_exponent_mask = (1U << 5U) - 1U;
assert(((exponent & ~half_exponent_mask) == 0U) && "Float exponent overflow");
return static_cast<uint16_t>(exponent & half_exponent_mask);
}

// Return mantissa value of 32 bits float. Note that mantissa for 32
// bits float is 23 bits and this method must return uint32_t.
uint32_t FloatMantissa(const uint32_t hex_float) {
return static_cast<uint32_t>(hex_float & ((1U << 23U) - 1U));
}

// Convert 32 bits float |value| to 16 bits float based on IEEE-754.
uint16_t FloatToHexFloat16(const float value) {
const uint32_t* hex = reinterpret_cast<const uint32_t*>(&value);
return static_cast<uint16_t>(
static_cast<uint16_t>(FloatSign(*hex) << 15U) |
static_cast<uint16_t>(FloatExponent(*hex) << 10U) |
static_cast<uint16_t>(FloatMantissa(*hex) >> 13U));
}

template <typename T>
T* ValuesAs(uint8_t* values) {
return reinterpret_cast<T*>(values);
Expand Down Expand Up @@ -82,10 +56,10 @@ double CalculateDiff(const Format::Segment* seg,
return Sub<uint32_t>(buf1, buf2);
if (type::Type::IsUint64(mode, num_bits))
return Sub<uint64_t>(buf1, buf2);
// TODO(dsinclair): Handle float16 ...
if (type::Type::IsFloat16(mode, num_bits)) {
assert(false && "Float16 suppport not implemented");
return 0.0;
float val1 = float16::HexFloatToFloat(buf1, 16);
float val2 = float16::HexFloatToFloat(buf2, 16);
return static_cast<double>(val1 - val2);
}
if (type::Type::IsFloat32(mode, num_bits))
return Sub<float>(buf1, buf2);
Expand Down Expand Up @@ -399,7 +373,7 @@ uint32_t Buffer::WriteValueFromComponent(const Value& value,
return sizeof(uint64_t);
}
if (type::Type::IsFloat16(mode, num_bits)) {
*(ValuesAs<uint16_t>(ptr)) = FloatToHexFloat16(value.AsFloat());
*(ValuesAs<uint16_t>(ptr)) = float16::FloatToHexFloat16(value.AsFloat());
return sizeof(uint16_t);
}
if (type::Type::IsFloat32(mode, num_bits)) {
Expand Down
24 changes: 24 additions & 0 deletions src/buffer_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

#include <limits>
#include "gtest/gtest.h"
#include "src/float16_helper.h"
#include "src/type_parser.h"

namespace amber {
Expand Down Expand Up @@ -294,4 +295,27 @@ TEST_F(BufferTest, CompareHistogramEMDToleranceAllWhite) {
EXPECT_TRUE(b1.CompareHistogramEMD(&b2, 0.0f).IsSuccess());
}

TEST_F(BufferTest, SetFloat16) {
std::vector<Value> values;
values.resize(2);
values[0].SetDoubleValue(2.8);
values[1].SetDoubleValue(1234.567);

TypeParser parser;
auto type = parser.Parse("R16_SFLOAT");

Format fmt(type.get());
Buffer b;
b.SetFormat(&fmt);
b.SetData(std::move(values));

EXPECT_EQ(2, b.ElementCount());
EXPECT_EQ(2, b.ValueCount());
EXPECT_EQ(4, b.GetSizeInBytes());

auto v = b.GetValues<uint16_t>();
EXPECT_EQ(float16::FloatToHexFloat16(2.8f), v[0]);
EXPECT_EQ(float16::FloatToHexFloat16(1234.567f), v[1]);
}

} // namespace amber
126 changes: 126 additions & 0 deletions src/float16_helper.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
// Copyright 2019 The Amber Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "src/float16_helper.h"

#include <cassert>

// Float10
// | 9 8 7 6 5 | 4 3 2 1 0 |
// | exponent | mantissa |
//
// Float11
// | 10 9 8 7 6 | 5 4 3 2 1 0 |
// | exponent | mantissa |
//
// Float16
// | 15 | 14 13 12 11 10 | 9 8 7 6 5 4 3 2 1 0 |
// | s | exponent | mantissa |
//
// Float32
// | 31 | 30 ... 23 | 22 ... 0 |
// | s | exponent | mantissa |

namespace amber {
namespace float16 {
namespace {

// Return sign value of 32 bits float.
uint16_t FloatSign(const uint32_t hex_float) {
return static_cast<uint16_t>(hex_float >> 31U);
}

// Return exponent value of 32 bits float.
uint16_t FloatExponent(const uint32_t hex_float) {
uint32_t exponent = ((hex_float >> 23U) & ((1U << 8U) - 1U)) - 112U;
const uint32_t half_exponent_mask = (1U << 5U) - 1U;
assert(((exponent & ~half_exponent_mask) == 0U) && "Float exponent overflow");
return static_cast<uint16_t>(exponent & half_exponent_mask);
}

// Return mantissa value of 32 bits float. Note that mantissa for 32
// bits float is 23 bits and this method must return uint32_t.
uint32_t FloatMantissa(const uint32_t hex_float) {
return static_cast<uint32_t>(hex_float & ((1U << 23U) - 1U));
}

// Convert float |value| whose size is 16 bits to 32 bits float
// based on IEEE-754.
float HexFloat16ToFloat(const uint8_t* value) {
uint32_t sign = (static_cast<uint32_t>(value[1]) & 0x80) << 24U;
uint32_t exponent = (((static_cast<uint32_t>(value[1]) & 0x7c) >> 2U) + 112U)
<< 23U;
uint32_t mantissa = ((static_cast<uint32_t>(value[1]) & 0x3) << 8U |
static_cast<uint32_t>(value[0]))
<< 13U;

uint32_t hex = sign | exponent | mantissa;
float* hex_float = reinterpret_cast<float*>(&hex);
return *hex_float;
}

// Convert float |value| whose size is 11 bits to 32 bits float
// based on IEEE-754.
float HexFloat11ToFloat(const uint8_t* value) {
uint32_t exponent = (((static_cast<uint32_t>(value[1]) << 2U) |
((static_cast<uint32_t>(value[0]) & 0xc0) >> 6U)) +
112U)
<< 23U;
uint32_t mantissa = (static_cast<uint32_t>(value[0]) & 0x3f) << 17U;

uint32_t hex = exponent | mantissa;
float* hex_float = reinterpret_cast<float*>(&hex);
return *hex_float;
}

// Convert float |value| whose size is 10 bits to 32 bits float
// based on IEEE-754.
float HexFloat10ToFloat(const uint8_t* value) {
uint32_t exponent = (((static_cast<uint32_t>(value[1]) << 3U) |
((static_cast<uint32_t>(value[0]) & 0xe0) >> 5U)) +
112U)
<< 23U;
uint32_t mantissa = (static_cast<uint32_t>(value[0]) & 0x1f) << 18U;

uint32_t hex = exponent | mantissa;
float* hex_float = reinterpret_cast<float*>(&hex);
return *hex_float;
}

} // namespace

float HexFloatToFloat(const uint8_t* value, uint8_t bits) {
switch (bits) {
case 10:
return HexFloat10ToFloat(value);
case 11:
return HexFloat11ToFloat(value);
case 16:
return HexFloat16ToFloat(value);
}

assert(false && "Invalid bits");
return 0;
}

uint16_t FloatToHexFloat16(const float value) {
const uint32_t* hex = reinterpret_cast<const uint32_t*>(&value);
return static_cast<uint16_t>(
static_cast<uint16_t>(FloatSign(*hex) << 15U) |
static_cast<uint16_t>(FloatExponent(*hex) << 10U) |
static_cast<uint16_t>(FloatMantissa(*hex) >> 13U));
}

} // namespace float16
} // namespace amber
Loading