Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 105 additions & 54 deletions src/meta_schedule/postproc/rewrite_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
* specific language governing permissions and limitations
* under the License.
*/
#include <optional>
#include <unordered_set>

#include "../utils.h"
Expand All @@ -25,23 +26,15 @@ namespace tir {

/*!
* \brief Collect the block and index where the buffer is read.
* \note The buffers are expected to be read by only one BufferLoad
* \note The buffer is expected to be read by only one BufferLoad
*/
class BufferReadPosCollector : public StmtExprVisitor {
public:
explicit BufferReadPosCollector(const Array<Buffer>& buffers) {
for (const Buffer& buf : buffers) {
buffers_.insert(buf.get());
}
}
explicit BufferReadPosCollector(const Buffer& buffer) : buffer_(buffer.get()) {}

const std::unordered_map<const BufferNode*, std::pair<Block, int>>& GetBufferLocations() const {
return buffer_locs_;
}
const std::pair<Block, int>& GetBufferLocation() const { return buffer_loc_; }

const std::unordered_map<const BufferNode*, Optional<IndexMap>>& GetBufferIndexMap() const {
return buffer_index_maps_;
}
const Optional<IndexMap> GetBufferIndexMap() const { return buffer_index_map_; }

private:
void VisitStmt_(const ForNode* op) final {
Expand All @@ -61,7 +54,7 @@ class BufferReadPosCollector : public StmtExprVisitor {
CHECK(cur_realize_.defined()) << "BufferLoad occurred outside of any block";

const Buffer& buffer = op->buffer;
if (buffers_.count(buffer.get())) {
if (buffer_ == buffer.get()) {
Map<Var, PrimExpr> subst_map;
for (size_t i = 0; i < cur_realize_->iter_values.size(); i++) {
const Var& var = cur_realize_->block->iter_vars[i]->var;
Expand All @@ -72,14 +65,14 @@ class BufferReadPosCollector : public StmtExprVisitor {
for (const PrimExpr& e : op->indices) {
subst_indices.push_back(Substitute(e, subst_map));
}
buffer_index_maps_[buffer.get()] = SuggestIndexMap(/*buffer=*/buffer, //
/*indices=*/subst_indices, //
/*loops=*/loop_stack_, //
/*predicate=*/cur_realize_->predicate, //
/*analyzer=*/&analyzer_);
buffer_index_map_ = SuggestIndexMap(/*buffer=*/buffer, //
/*indices=*/subst_indices, //
/*loops=*/loop_stack_, //
/*predicate=*/cur_realize_->predicate, //
/*analyzer=*/&analyzer_);
int buffer_index = GetReadBufferIndex(cur_realize_->block, buffer);
ICHECK(buffer_index != -1);
buffer_locs_[buffer.get()] = std::make_pair(cur_realize_->block, buffer_index);
buffer_loc_ = std::make_pair(cur_realize_->block, buffer_index);
}
}

Expand All @@ -93,12 +86,12 @@ class BufferReadPosCollector : public StmtExprVisitor {
}

private:
/*! \brief All interested buffer. */
std::unordered_set<const BufferNode*> buffers_;
/*! \brief The result mapping from buffer to its inner-most block and read index. */
std::unordered_map<const BufferNode*, std::pair<Block, int>> buffer_locs_;
/*! \brief The result mapping from buffer to its IndexMap. */
std::unordered_map<const BufferNode*, Optional<IndexMap>> buffer_index_maps_;
/*! \brief The buffer of interest. */
const BufferNode* buffer_;
/*! \brief The block that consumes the buffer and the corresponding read index. */
std::pair<Block, int> buffer_loc_;
/*! \brief The proposed IndexMap. */
Optional<IndexMap> buffer_index_map_;

/*! \brief Loop stack for calculating IndexMap. */
Array<For> loop_stack_;
Expand Down Expand Up @@ -143,8 +136,56 @@ Array<Buffer> CollectLayoutFreeBuffers(const PrimFuncNode* func) {
return layout_free_buffers;
}

/*!
 * \brief Run BufferReadPosCollector over a function body and package its result for one buffer.
 * \param buffer The buffer whose read position and index map are requested.
 * \param prim_func The function whose body is traversed.
 * \return A tuple of (block that consumes the buffer, read index of the buffer in that block,
 *         proposed IndexMap), or std::nullopt when the collector holds no IndexMap.
 */
std::optional<std::tuple<Block, int, IndexMap>> GetSuggestedIndexMap(
    Buffer buffer, const PrimFuncNode* prim_func) {
  BufferReadPosCollector read_pos(buffer);
  read_pos(prim_func->body);

  const Optional<IndexMap> proposed = read_pos.GetBufferIndexMap();
  if (proposed.defined() && proposed) {
    // Only meaningful once an index map was produced for the buffer.
    const auto& location = read_pos.GetBufferLocation();
    return std::make_tuple(location.first, location.second, proposed.value());
  }
  return std::nullopt;
}

/*! \brief Get a chain of cache-read blocks, starting from the one consuming buf. */
std::vector<std::string> GetCacheReadChain(const Buffer& buf, const PrimFuncNode* prim_func) {
class BufferReadChainCollector : public StmtVisitor {
public:
explicit BufferReadChainCollector(const Buffer& buffer) : cur_buffer_(buffer.get()) {}

void VisitStmt_(const BlockNode* op) final {
// Check if this block is doing cache_read or a similar operation that consumes cur_buffer_.
if (!op->init && op->reads.size() == 1 && op->writes.size() == 1 &&
op->reads[0]->buffer.get() == cur_buffer_) {
cache_read_chain.push_back(op->name_hint);
cur_buffer_ = op->writes[0]->buffer.get();
}
StmtVisitor::VisitStmt_(op);
}

std::vector<std::string> cache_read_chain;

private:
const BufferNode* cur_buffer_;
};

BufferReadChainCollector collector(buf);
collector(prim_func->body);
return collector.cache_read_chain;
}

bool RewriteLayout(const Schedule& sch) {
std::vector<std::pair<StmtSRef, String>> results;
auto add_layout_rewrite_block = [&sch](BlockRV consumer_block_rv, int buffer_index) {
BlockRV rewrite_block_rv = sch->CacheRead(consumer_block_rv, buffer_index, "global");
sch->Annotate(rewrite_block_rv, attr::meta_schedule_layout_rewrite_preproc, const_true());
};

for (const auto& [g_var, base_func] : sch->mod()->functions) {
const String& func_name = g_var->name_hint;
const auto* prim_func = base_func.as<PrimFuncNode>();
Expand All @@ -153,36 +194,46 @@ bool RewriteLayout(const Schedule& sch) {
continue;
}

Array<Buffer> layout_free_buffers = CollectLayoutFreeBuffers(prim_func);

// Collect Buffer read positions
BufferReadPosCollector collector(layout_free_buffers);
collector(prim_func->body);
const auto& locations = collector.GetBufferLocations();
const auto& index_maps = collector.GetBufferIndexMap();
// Check all buffers are collected
if (locations.size() != layout_free_buffers.size() ||
index_maps.size() != layout_free_buffers.size()) {
return false;
}

for (const auto& kv : locations) {
const Buffer& buffer = GetRef<Buffer>(kv.first);
const Block& block = kv.second.first;
int buffer_index = kv.second.second;

// Get IndexMap
const Optional<IndexMap> index_map = index_maps.at(buffer.get());
if (!index_map.defined()) {
continue;
for (auto buffer : CollectLayoutFreeBuffers(prim_func)) {
const auto cache_read_chain = GetCacheReadChain(buffer, prim_func);
if (cache_read_chain.empty()) {
// The common case, where the layout-free buffer is directly consumed by an anchor op such
// as conv2d or dense.
auto tup_opt = GetSuggestedIndexMap(buffer, prim_func);
if (tup_opt == std::nullopt) continue;

auto [anchor_block, buffer_index, index_map] = *tup_opt;
auto anchor_block_rv = sch->GetBlock(anchor_block->name_hint, func_name);
add_layout_rewrite_block(anchor_block_rv, buffer_index);
sch->TransformLayout(anchor_block_rv, buffer_index, BufferIndexType::kRead, index_map,
NullOpt);
} else {
// When the layout-free buffer is consumed by cache_read, we need to find the index map
// for a cache-read buffer that is directly consumed by an anchor op. The last buffer
// in cache_read_chain corresponds to that buffer.
Block cache_read_block = sch->Get(sch->GetBlock(cache_read_chain.back(), func_name));
ICHECK_EQ(cache_read_block->writes.size(), 1);
auto tup_opt = GetSuggestedIndexMap(cache_read_block->writes[0]->buffer, prim_func);
if (tup_opt == std::nullopt) continue;

auto [anchor_block, buffer_index, index_map] = *tup_opt;
// Transform the layout of the last cache-read buffer.
sch->TransformLayout(sch->GetBlock(anchor_block->name_hint, func_name), buffer_index,
BufferIndexType::kRead, index_map, NullOpt);

// Propagate the layout transformation over cache_read_chain, starting from
// the next-to-last cache-read buffer.
for (int i = static_cast<int>(cache_read_chain.size()) - 1; i >= 0; --i) {
BlockRV cache_read_block_rv = sch->GetBlock(cache_read_chain[i], func_name);
if (i == 0) {
// Before the first cache_read that consumes the layout-free buffer, insert
// a layout-rewrite block. Another cache-read buffer is added, and its layout is
// transformed by TransformLayout below.
add_layout_rewrite_block(cache_read_block_rv, 0);
}
sch->TransformLayout(cache_read_block_rv, 0, BufferIndexType::kRead, index_map, NullOpt);
}
}

// Apply schedule
BlockRV block_rv = sch->GetBlock(block->name_hint, func_name);
BlockRV cached_block_rv = sch->CacheRead(block_rv, buffer_index, "global");
sch->TransformLayout(block_rv, buffer_index, BufferIndexType::kRead, index_map.value(),
NullOpt);
sch->Annotate(cached_block_rv, attr::meta_schedule_layout_rewrite_preproc, const_true());
}
}
return true;
Expand Down
9 changes: 8 additions & 1 deletion src/te/operation/create_primfunc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -110,13 +110,20 @@ class LayoutFreePlaceholdersNormalizer : public StmtMutator {
Block block = Downcast<Block>(StmtMutator::VisitStmt_(_block));
BlockNode* n = block.CopyOnWrite();
if (Optional<ObjectRef> ann = n->annotations.Get(topi_attr)) {
Array<Buffer> new_buffers;
for (Buffer buffer : Downcast<Array<Buffer>>(ann)) {
auto it = buffer2index_.find(buffer);
if (it != buffer2index_.end()) {
layout_free_buffer_indices_.insert(it->second);
} else {
new_buffers.push_back(buffer);
}
}
n->annotations.erase(topi_attr);
if (new_buffers.empty()) {
n->annotations.erase(topi_attr);
} else {
n->annotations.Set(topi_attr, new_buffers);
}
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a fix for RewriteLayout to work when link-params = True, the breakage discussed in #13195. With this, I confirmed that the new RewriteLayout implementation works when there is CacheRead acting on AllocateConst. I'm not adding a test since TVMScript parsing is broken when there is AllocateConst with a large constant (can't print the whole array).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Continuing the discussion from this thread.

I am convinced that we need to store buffers in block attributes for Hexagon-specific use cases; the only thing that feels less natural to me is storing IR nodes in attributes in general. As an example, in layout_free_placeholders we didn't store any IR nodes in the attribute, but instead used a list of integers, which is simpler.

In the meantime, I completely understand that we need to get around this quickly, so in this particular case, how about adding or reusing a pass config flag, e.g. the "link-params" one we are already using, and only allowing topi_attr to be set when the flag is on?

else if (link_param is on) {
  n->annotations.Set(topi_attr, new_buffers);
}

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

An alternative idea: having an extra flag in CreatePrimFunc that instructs the translator to retain buffers in block annotation, and the flag is on only if link-param is detected as "on" in TECompiler.

Copy link
Member Author

@masahi masahi Nov 15, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This else branch is only hit when there is a layout-free buffer that's not passed as a parameter to the prim func. Currently that can only happen when link-params = True (this is exactly what link-params is meant for). So isn't having another flag redundant?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. That makes perfect sense to me :-) As long as it's not affecting normal lowering flow, it looks like a good idea!

}
for (const String& attr : this->blocklist) {
auto it = n->annotations.find(attr);
Expand Down
Loading