Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 137 additions & 0 deletions include/tvm/relay/attrs/bitserial.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tvm/relay/attrs/bitserial.h
* \brief Auxiliary attributes for bitserial operators.
*/

#ifndef TVM_RELAY_ATTRS_BITSERIAL_H_
#define TVM_RELAY_ATTRS_BITSERIAL_H_

#include <tvm/attrs.h>
#include <tvm/relay/base.h>
#include <string>

namespace tvm {
namespace relay {

/*! \brief Attributes used in bitpack operators */
struct BitPackAttrs : public tvm::AttrsNode<BitPackAttrs> {
int bits;
int pack_axis;
int bit_axis;
DataType pack_type;
std::string name;

TVM_DECLARE_ATTRS(BitPackAttrs, "relay.attrs.BitPackAttrs") {
TVM_ATTR_FIELD(bits).set_default(1).describe("Number of bits to quantize with.");
TVM_ATTR_FIELD(pack_axis).set_default(1).describe(
"Axis that should be compressed, typically channels.");
TVM_ATTR_FIELD(bit_axis).set_default(-1).describe("New axis for packed bits.");
TVM_ATTR_FIELD(pack_type)
.set_default(NullValue<DataType>())
.describe("Type of int to pack bits into.");
TVM_ATTR_FIELD(name).set_default("BitPack").describe("Name of operation.");
}
};

/*! \brief Attribues used in bitserial convolution operators */
struct BinaryConv2DAttrs : public tvm::AttrsNode<BinaryConv2DAttrs> {
Array<IndexExpr> strides;
Array<IndexExpr> padding;
IndexExpr channels;
Array<IndexExpr> kernel_size;
int activation_bits;
int weight_bits;
std::string data_layout;
std::string kernel_layout;
DataType pack_dtype;
DataType out_dtype;
bool unipolar;

TVM_DECLARE_ATTRS(BinaryConv2DAttrs, "relay.attrs.BinaryConv2DAttrs") {
TVM_ATTR_FIELD(strides)
.set_default(Array<IndexExpr>({1, 1}))
.describe("Specifies the strides of the convolution.");
TVM_ATTR_FIELD(padding)
.set_default(Array<IndexExpr>({0, 0}))
.describe(
"If padding is non-zero the input is implicitly zero-padded"
"on both sides for padding number of points.");
TVM_ATTR_FIELD(kernel_size)
.set_default(Array<IndexExpr>({3, 3}))
.describe("Specifies the dimensions of the convolution window.");
TVM_ATTR_FIELD(channels)
.set_default(NullValue<IndexExpr>())
.describe("Number of output channels, needed for shape inference.");
TVM_ATTR_FIELD(activation_bits)
.set_default(1)
.describe("Number of bits activation should be packed with.");
TVM_ATTR_FIELD(weight_bits)
.set_default(1)
.describe("Number of bits kernel should be packed with.");
TVM_ATTR_FIELD(data_layout)
.set_default("NCHW")
.describe("Dimension ordering of input data, can be 'NCHW' or NHWC'.");
TVM_ATTR_FIELD(kernel_layout)
.set_default("OIHW")
.describe("Dimension ordering of kernel data, can be 'OIHW' or HWIO'.");
TVM_ATTR_FIELD(pack_dtype)
.set_default(NullValue<DataType>())
.describe("Datatype to pack bits into.");
TVM_ATTR_FIELD(out_dtype).set_default(NullValue<DataType>()).describe("Output datatype.");
TVM_ATTR_FIELD(unipolar).set_default(true).describe(
"Whether to use unipolar or bipolar quantization.");
}
};

/*~ \brief Attributes for bitserial dense operator */
struct BinaryDenseAttrs : public tvm::AttrsNode<BinaryDenseAttrs> {
IndexExpr units;
int data_bits;
int weight_bits;
DataType pack_dtype;
DataType out_dtype;
bool unipolar;

TVM_DECLARE_ATTRS(BinaryDenseAttrs, "relay.attrs.BinaryDenseAttrs") {
TVM_ATTR_FIELD(units)
.describe("Number of hidden units of the dense transformation.");
TVM_ATTR_FIELD(data_bits)
.set_default(1)
.describe("Number of bits to pack for incoming tensor.");
TVM_ATTR_FIELD(weight_bits)
.set_default(1)
.describe("Number of bits to pack for weight tensor.");
TVM_ATTR_FIELD(pack_dtype)
.set_default(NullValue<DataType>())
.describe("Datatype to pack bits into before computation.");
TVM_ATTR_FIELD(out_dtype)
.set_default(NullValue<DataType>())
.describe("Output data type.");
TVM_ATTR_FIELD(unipolar)
.set_default(true)
.describe("Whether to use unipolar or bipolar quantization for inputs.");
}
};

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_BITSERIAL_H_
117 changes: 117 additions & 0 deletions python/tvm/relay/op/nn/_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,3 +600,120 @@ def schedule_deformable_conv2d(attrs, outs, target):


reg.register_pattern("nn.deformable_conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)


@reg.register_compute("nn.bitpack")
def compute_bitpack(attrs, inputs, out_dtype, target):
"""Compute definition for bitpack"""
bits = attrs.bits
pack_axis = attrs.pack_axis
bit_axis = attrs.bit_axis
pack_type = attrs.pack_type
name = attrs.name
with target:
out = topi.nn.bitpack(inputs[0], bits, pack_axis, bit_axis, pack_type,
name)
return [out]

@reg.register_schedule("nn.bitpack")
def schedule_bitpack(attrs, outs, target):
with target:
return topi.generic.schedule_bitpack(outs)

reg.register_pattern("nn.bitpack", OpPattern.INJECTIVE)


@reg.register_compute("nn.bitserial_conv2d")
def compute_bitserial_conv2d(attrs, inputs, out_dtype, target):
"""Compute definition for bitserial conv2d."""
padding = get_const_tuple(attrs.padding)
strides = get_const_tuple(attrs.strides)
activation_bits = attrs.activation_bits
weight_bits = attrs.weight_bits
layout = attrs.data_layout
pack_dtype = attrs.pack_dtype
out_dtype = attrs.out_dtype
unipolar = attrs.unipolar
if layout == 'NCHW':
with target:
out = topi.nn.bitserial_conv2d_nchw(
inputs[0], inputs[1], strides, padding, activation_bits,
weight_bits, pack_dtype, out_dtype, unipolar)
elif layout == 'NHWC':
with target:
out = topi.nn.bitserial_conv2d_nhwc(
inputs[0], inputs[1], strides, padding, activation_bits,
weight_bits, pack_dtype, out_dtype, unipolar)
else:
raise ValueError("Data layout not supported.")

return [out]


@reg.register_schedule("nn.bitserial_conv2d")
def schedule_bitserial_conv2d(attrs, outs, target):
"""Schedule definition for bitserial conv2d."""
layout = attrs.data_layout
if layout == 'NCHW':
with target:
return topi.generic.schedule_bitserial_conv2d_nchw(outs)
elif layout == 'NHWC':
with target:
return topi.generic.schedule_bitserial_conv2d_nhwc(outs)
else:
raise ValueError("Data layout not supported.")

@reg.register_legalize("nn.bitserial_conv2d")
def legalize_bitserial_conv2d(attrs, inputs, types):
"""Legalize bitserial_conv2d op.

Parameters
----------
attrs : tvm.attrs.Attrs
Attributes of current convolution
inputs : list of tvm.relay.Expr
The args of the Relay expr to be legalized
types : list of types
List of input and output types

Returns
-------
result : tvm.relay.Expr
The legalized expr
"""
return topi.nn.bitserial_conv2d_legalize(attrs, inputs, types)


reg.register_pattern("nn.bitserial_conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)


# bitserial_dense
@reg.register_compute("nn.bitserial_dense")
def compute_bitserial_dense(attrs, inputs, out_type, target):
"""Compute definition of bitserial_dense"""
data_bits = attrs.data_bits
weight_bits = attrs.weight_bits
pack_dtype = attrs.pack_dtype
out_dtype = attrs.out_dtype
out_dtype = inputs[0].dtype if out_dtype == "" else out_dtype
unipolar = attrs.unipolar
return [
topi.nn.bitserial_dense(
inputs[0],
inputs[1],
data_bits,
weight_bits,
pack_dtype,
out_dtype,
unipolar)
]


@reg.register_schedule("nn.bitserial_dense")
def schedule_bitserial_dense(attrs, outputs, target):
"""Schedule definition of bitserial_dense"""
with target:
return topi.generic.schedule_bitserial_dense(outputs)


reg.register_pattern("nn.bitserial_dense", reg.OpPattern.OUT_ELEMWISE_FUSABLE)
Loading