Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 16 additions & 13 deletions include/TaskflowDialect/TaskflowOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -39,38 +39,40 @@ def TaskflowTaskOp : TaskflowOpBase<"task", [
1. Memory dependencies: memrefs that are read or written by the task
2. Value dependencies: SSA values from producer tasks

The `read_memrefs` and `write_memrefs` attributes record the actual
The `original_read_memrefs` and `original_write_memrefs` attributes record the actual
original memrefs that this task accesses,
enabling data placement analysis for multi-CGRA mapping.

Example:
// Memory inputs: %mem, Value inputs: %val
%out_mem, %out_val = taskflow.task "Task_0"
read_inputs(%mem : memref<4xi32>)
dependency_read_in(%mem : memref<4xi32>)
dependency_write_in(%arg5 : memref<?xi32>)
value_inputs(%val : i32)
original_read_memrefs(%arg0 : memref<?x8x6xi32>)
original_write_memrefs(%arg5 : memref<?xi32>) {
^bb0(%a0: memref<4xi32>, %a1: i32):
[original_read_memrefs(%arg0 : memref<?x8x6xi32>),
original_write_memrefs(%arg5 : memref<?xi32>)] {
^bb0(%a0: memref<4xi32>, %a1: memref<?xi32>, %a2: i32):
affine.for %i = 0 to 4 {
%v = affine.load %a0[%i] : memref<4xi32>
%sum = arith.addi %v, %a1 : i32
affine.store %sum, %a0[%i] : memref<4xi32>
%sum = arith.addi %v, %a2 : i32
affine.store %sum, %a1[%i] : memref<?xi32>
}
taskflow.yield memory_outputs(%a0 : memref<4xi32>) value_outputs(%a1 : i32)
} : (memref<4xi32>, i32) -> (memref<4xi32>, i32)
taskflow.yield reads(%a0 : memref<4xi32>) writes(%a1 : memref<?xi32>) values(%a2 : i32)
} : (memref<4xi32>, memref<?xi32>, i32) -> (memref<4xi32>, memref<?xi32>, i32)
}];

let arguments = (ins
Variadic<AnyMemRef>:$read_memrefs,
Variadic<AnyMemRef>:$write_memrefs,
Variadic<AnyMemRef>:$dependency_read_in,
Variadic<AnyMemRef>:$dependency_write_in,
Variadic<AnyType>:$value_inputs,
StrAttr:$task_name,
Variadic<AnyMemRef>:$original_read_memrefs,
Variadic<AnyMemRef>:$original_write_memrefs
);

let results = (outs
Variadic<AnyMemRef>:$write_outputs,
Variadic<AnyMemRef>:$dependency_read_out,
Variadic<AnyMemRef>:$dependency_write_out,
Variadic<AnyType>:$value_outputs
);

Expand All @@ -94,6 +96,7 @@ def TaskflowYieldOp : TaskflowOpBase<"yield", [Terminator, Pure, ReturnLike, Att
}];

let arguments = (ins
Variadic<AnyMemRef>:$read_results,
Variadic<AnyMemRef>:$memory_results,
Variadic<AnyType>:$value_results);

Expand All @@ -102,7 +105,7 @@ def TaskflowYieldOp : TaskflowOpBase<"yield", [Terminator, Pure, ReturnLike, Att
let builders = [
// Default builder for empty yield.
OpBuilder<(ins), [{
build($_builder, $_state, ValueRange{}, ValueRange{});
build($_builder, $_state, ValueRange{}, ValueRange{}, ValueRange{});
}]>
];
}
Expand Down
40 changes: 34 additions & 6 deletions lib/Conversion/AffineToTaskflow/AffineToTaskflowPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,12 @@ static TaskflowTaskOp convertLoopToTask(
//-------------------------------------------------------------------
// Step 5: Prepares output types.
//-------------------------------------------------------------------
// Read output types: passthrough read memrefs for WAR dependency tracking.
SmallVector<Type> read_output_types;
for (Value memref : read_memrefs) {
read_output_types.push_back(memref.getType());
}

SmallVector<Type> memory_output_types;
for (Value memref : output_memrefs) {
memory_output_types.push_back(memref.getType());
Expand All @@ -248,7 +254,8 @@ static TaskflowTaskOp convertLoopToTask(
//-------------------------------------------------------------------
TaskflowTaskOp task_op = builder.create<TaskflowTaskOp>(
loc,
/*memory_outputs=*/memory_output_types,
/*read_outputs=*/read_output_types,
/*write_outputs=*/memory_output_types,
/*value_outputs=*/value_output_types,
/*read_inputs=*/read_inputs,
/*write_inputs=*/write_inputs,
Expand Down Expand Up @@ -294,13 +301,23 @@ static TaskflowTaskOp convertLoopToTask(
// Step 8: Creates the yield operation.
//---------------------------------------------------------------
task_builder.setInsertionPointToEnd(task_body);
SmallVector<Value> memory_yield_operands;
SmallVector<Value> yield_for_dependency_read_out;
SmallVector<Value> yield_for_dependency_write_out;
SmallVector<Value> value_yield_operands;

// Read yield outputs: passthrough read memref block args for WAR tracking.
for (Value memref : read_memrefs) {
if (input_to_block_arg.count(memref)) {
yield_for_dependency_read_out.push_back(input_to_block_arg[memref]);
} else {
assert(false && "Read memref not in inputs!");
}
}

// Memory yield outputs: yield the written memrefs.
for (Value memref : output_memrefs) {
if (input_to_block_arg.count(memref)) {
memory_yield_operands.push_back(input_to_block_arg[memref]);
yield_for_dependency_write_out.push_back(input_to_block_arg[memref]);
} else {
assert(false && "Written memref not in inputs!");
}
Expand All @@ -310,16 +327,27 @@ static TaskflowTaskOp convertLoopToTask(
for (Value result : cloned_loop->getResults()) {
value_yield_operands.push_back(result);
}
task_builder.create<TaskflowYieldOp>(loc, memory_yield_operands,
task_builder.create<TaskflowYieldOp>(loc, yield_for_dependency_read_out,
yield_for_dependency_write_out,
value_yield_operands);

//-------------------------------------------------------------------
// Step 9: Updates value mapping with task outputs for subsequent tasks
// conversion.
//-------------------------------------------------------------------
// Memory outputs.
// Read outputs: establishes WAR dependency chain.
// Only update mapping for memrefs not already mapped by a prior write.
for (auto [memref, task_read_output] :
llvm::zip(read_memrefs, task_op.getDependencyReadOut())) {
if (!value_mapping.count(memref)) {
value_mapping[memref] = task_read_output;
}
}

// Memory outputs (write): establishes RAW/WAW dependency chain.
// Write outputs always overwrite read outputs in the mapping.
for (auto [memref, task_output] :
llvm::zip(output_memrefs, task_op.getWriteOutputs())) {
llvm::zip(output_memrefs, task_op.getDependencyWriteOut())) {
value_mapping[memref] = task_output;
}

Expand Down
78 changes: 54 additions & 24 deletions lib/TaskflowDialect/TaskflowOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,20 @@ ParseResult TaskflowTaskOp::parse(OpAsmParser &parser, OperationState &result) {
result.addAttribute("task_name", task_name);
}

// Parses read_memrefs: read_memrefs(%arg0, %arg1 : memref<?xi32>,
// Parses dependency_read_in: dependency_read_in(%arg0, %arg1 : memref<?xi32>,
// memref<?xi32>).
SmallVector<OpAsmParser::UnresolvedOperand> read_operands;
SmallVector<Type> read_types;
if (succeeded(parser.parseOptionalKeyword("read_memrefs"))) {
if (succeeded(parser.parseOptionalKeyword("dependency_read_in"))) {
if (parser.parseLParen() || parser.parseOperandList(read_operands) ||
parser.parseColonTypeList(read_types) || parser.parseRParen())
return failure();
}

// Parses write_memrefs: write_memrefs(%arg5 : memref<?xi32>).
// Parses dependency_write_in: dependency_write_in(%arg5 : memref<?xi32>).
SmallVector<OpAsmParser::UnresolvedOperand> write_operands;
SmallVector<Type> write_types;
if (succeeded(parser.parseOptionalKeyword("write_memrefs"))) {
if (succeeded(parser.parseOptionalKeyword("dependency_write_in"))) {
if (parser.parseLParen() || parser.parseOperandList(write_operands) ||
parser.parseColonTypeList(write_types) || parser.parseRParen())
return failure();
Expand Down Expand Up @@ -132,6 +132,9 @@ ParseResult TaskflowTaskOp::parse(OpAsmParser &parser, OperationState &result) {
static_cast<int32_t>(original_write_operands.size())}));

// Adds result segment sizes.
// dependency_read_out count matches dependency_read_in count (WAR dependency
// tracking).
size_t num_read_outputs = read_operands.size();
size_t num_write_outputs = 0;
size_t num_value_outputs = 0;
for (Type t : func_type.getResults()) {
Expand All @@ -140,9 +143,13 @@ ParseResult TaskflowTaskOp::parse(OpAsmParser &parser, OperationState &result) {
else
num_value_outputs++;
}
// Total memref results include both dependency_read_out and
// dependency_write_out.
num_write_outputs = num_write_outputs - num_read_outputs;
result.addAttribute("resultSegmentSizes",
parser.getBuilder().getDenseI32ArrayAttr(
{static_cast<int32_t>(num_write_outputs),
{static_cast<int32_t>(num_read_outputs),
static_cast<int32_t>(num_write_outputs),
static_cast<int32_t>(num_value_outputs)}));

return success();
Expand All @@ -152,21 +159,21 @@ void TaskflowTaskOp::print(OpAsmPrinter &printer) {
// Prints task name.
printer << " @" << getTaskName();

// Prints read_memrefs.
if (!getReadMemrefs().empty()) {
printer << " read_memrefs(";
llvm::interleaveComma(getReadMemrefs(), printer);
// Prints dependency_read_in.
if (!getDependencyReadIn().empty()) {
printer << " dependency_read_in(";
llvm::interleaveComma(getDependencyReadIn(), printer);
printer << " : ";
llvm::interleaveComma(getReadMemrefs().getTypes(), printer);
llvm::interleaveComma(getDependencyReadIn().getTypes(), printer);
printer << ")";
}

// Prints write_memrefs.
if (!getWriteMemrefs().empty()) {
printer << " write_memrefs(";
llvm::interleaveComma(getWriteMemrefs(), printer);
// Prints dependency_write_in.
if (!getDependencyWriteIn().empty()) {
printer << " dependency_write_in(";
llvm::interleaveComma(getDependencyWriteIn(), printer);
printer << " : ";
llvm::interleaveComma(getWriteMemrefs().getTypes(), printer);
llvm::interleaveComma(getDependencyWriteIn().getTypes(), printer);
printer << ")";
}

Expand Down Expand Up @@ -213,14 +220,17 @@ void TaskflowTaskOp::print(OpAsmPrinter &printer) {

// Prints function type.
printer << " : (";
llvm::interleaveComma(llvm::concat<const Type>(getReadMemrefs().getTypes(),
getWriteMemrefs().getTypes(),
getValueInputs().getTypes()),
printer);
llvm::interleaveComma(
llvm::concat<const Type>(getDependencyReadIn().getTypes(),
getDependencyWriteIn().getTypes(),
getValueInputs().getTypes()),
printer);
printer << ") -> (";
llvm::interleaveComma(llvm::concat<const Type>(getWriteOutputs().getTypes(),
getValueOutputs().getTypes()),
printer);
llvm::interleaveComma(
llvm::concat<const Type>(getDependencyReadOut().getTypes(),
getDependencyWriteOut().getTypes(),
getValueOutputs().getTypes()),
printer);
printer << ")";

// Prints region.
Expand All @@ -234,11 +244,20 @@ void TaskflowTaskOp::print(OpAsmPrinter &printer) {

ParseResult TaskflowYieldOp::parse(OpAsmParser &parser,
OperationState &result) {
SmallVector<OpAsmParser::UnresolvedOperand> read_operands;
SmallVector<Type> read_types;
SmallVector<OpAsmParser::UnresolvedOperand> write_operands;
SmallVector<Type> write_types;
SmallVector<OpAsmParser::UnresolvedOperand> value_operands;
SmallVector<Type> value_types;

// Parses reads (WAR dependency passthrough).
if (succeeded(parser.parseOptionalKeyword("reads"))) {
if (parser.parseLParen() || parser.parseOperandList(read_operands) ||
parser.parseColonTypeList(read_types) || parser.parseRParen())
return failure();
}

// Parses writes.
if (succeeded(parser.parseOptionalKeyword("writes"))) {
if (parser.parseLParen() || parser.parseOperandList(write_operands) ||
Expand All @@ -253,21 +272,32 @@ ParseResult TaskflowYieldOp::parse(OpAsmParser &parser,
return failure();
}

if (parser.resolveOperands(write_operands, write_types,
if (parser.resolveOperands(read_operands, read_types,
parser.getCurrentLocation(), result.operands) ||
parser.resolveOperands(write_operands, write_types,
parser.getCurrentLocation(), result.operands) ||
parser.resolveOperands(value_operands, value_types,
parser.getCurrentLocation(), result.operands))
return failure();

result.addAttribute("operandSegmentSizes",
parser.getBuilder().getDenseI32ArrayAttr(
{static_cast<int32_t>(write_operands.size()),
{static_cast<int32_t>(read_operands.size()),
static_cast<int32_t>(write_operands.size()),
static_cast<int32_t>(value_operands.size())}));

return success();
}

void TaskflowYieldOp::print(OpAsmPrinter &printer) {
if (!getReadResults().empty()) {
printer << " reads(";
llvm::interleaveComma(getReadResults(), printer);
printer << " : ";
llvm::interleaveComma(getReadResults().getTypes(), printer);
printer << ")";
}

if (!getMemoryResults().empty()) {
printer << " writes(";
llvm::interleaveComma(getMemoryResults(), printer);
Expand Down
Loading