Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
711 changes: 543 additions & 168 deletions csrc/dynamic_transform.cpp

Large diffs are not rendered by default.

124 changes: 103 additions & 21 deletions csrc/dynamic_transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class TORCH_CUDA_CU_API DynamicTransformInitialInfo {

//! Return whether any dynamic transforms exist in the Fusion
bool hasDynamicTransforms() const {
  // A single topologically-sorted list now tracks the outputs of all
  // dynamic expressions (reshape, resize, slice, ...), so the Fusion is
  // dynamic exactly when that list is non-empty. The old per-op checks
  // against dynamic_reshaped_tvs_/dynamic_resized_ids_ are obsolete.
  return !dynamic_expr_outputs_.empty();
}

//! Return a set of scalars that are inputs or extents of input TensorViews
Expand All @@ -50,16 +50,8 @@ class TORCH_CUDA_CU_API DynamicTransformInitialInfo {
return root_dynamic_vals_;
}

//! Return a vector of outputs of ViewOp expressions that have dynamic output
//! shapes
const std::vector<TensorView*>& getDynamicReshapedTensorViews() const {
return dynamic_reshaped_tvs_;
}

//! Return a vector of outputs of Resize expressions that have symbolic output
//! IterTypes
const std::vector<IterDomain*>& getDynamicResizedIterDomains() const {
return dynamic_resized_ids_;
const std::vector<Val*>& getDynamicExprOutputs() const {
return dynamic_expr_outputs_;
}

std::string toString() const;
Expand Down Expand Up @@ -89,16 +81,85 @@ class TORCH_CUDA_CU_API DynamicTransformInitialInfo {
// definitions will merely be altered. When the ops are replaced, if we had
// referred to them directly here, we would run into segfaults. Referring only
// to the outputs avoids this issue.
std::vector<TensorView*> dynamic_reshaped_tvs_;

std::vector<IterDomain*> dynamic_resized_ids_;
// std::vector<TensorView*> dynamic_reshaped_tvs_;

// std::vector<IterDomain*> dynamic_resized_ids_;

// Slice operations can have complicated output extents. The inputs to slice
// are a start, stop, and step for each sliced dimension. Each of these is an
// integer, and any combination of three finite integers with step != 0 is
// acceptable and should run without error. Normalization of the start and
// stop values must be done, followed by computation of the output extent:
//
//   normed_start = min(max(where(start < 0, extent + start, start), 0), extent);
//   normed_stop = max(min(max(where(stop < 0, extent + stop, stop), 0), extent),
//                     normed_start);
//   extent = max((normed_stop - normed_start + 1) / step, 0);
//
// These expressions are unwieldy and cannot be significantly simplified
// unless we know certain relations about the start, stop, and step scalars.
// Here we keep track of non-static slices or slices with non-static input
// extents. That way we can restrict to a single branch in each of these
// expressions during concretization.
// std::vector<TensorView*> dynamic_sliced_tvs_;

// This is a topologically sorted list of outputs of dynamic operations.
std::vector<Val*> dynamic_expr_outputs_;

// Root Vals that determine concretization
std::unordered_set<Val*> root_dynamic_vals_;

friend class DynamicTransformInitialInfoBuilder;
};

//! This enum describes cases that can occur for the start or stop arguments to
//! slice(). Each of these leads to a different branch in the normalized form's
//! general expression. Below, `a` is the (possibly negative) start/stop value
//! and `extent` is the extent of the sliced input dimension.
enum class SliceIndexBranch {
  Negative, // -extent < a < 0 : index is relative to the end of the dimension
  Zero, // a == 0 OR a <= -extent : clamps to the start of the dimension
  Positive, // 0 < a < extent : ordinary in-range index
  Extent // extent <= a : clamps to the end of the dimension
};

//! This enum describes the "step" argument to slice, which can be a positive or
//! negative integer (but not zero). We handle the special case of step == 1
//! separately from step > 1 since this simplifies some expressions.
//! Negative: step < 0; One: step == 1; GreaterThanOne: step > 1.
enum class SliceStepBranch { Negative, One, GreaterThanOne };

//! Describes a 1D slice in terms of the start, stop, and extent values
struct Concrete1DSliceDescriptor {
  //! These enums determine the form of the simplified expressions
  SliceIndexBranch start_branch = SliceIndexBranch::Zero;
  SliceIndexBranch stop_branch = SliceIndexBranch::Extent;
  SliceStepBranch step_branch = SliceStepBranch::One;

  //! True if normalized values satisfy (stop - start) * step <= 0 in which case
  //! we would return an empty tensor.
  bool is_empty = false;

  //! This can be either Iteration or Broadcast (if sliced extent is 1)
  IterType iter_type = IterType::Iteration;

  //! Two descriptors compare equal only when every field matches, i.e. when
  //! they would lead to structurally identical concretizations.
  bool operator==(const Concrete1DSliceDescriptor& other) const {
    if (start_branch != other.start_branch ||
        stop_branch != other.stop_branch) {
      return false;
    }
    if (step_branch != other.step_branch || is_empty != other.is_empty) {
      return false;
    }
    return iter_type == other.iter_type;
  }
  bool operator!=(const Concrete1DSliceDescriptor& other) const {
    return !(*this == other);
  }

  //! Fold every field into a single hash value, consistent with operator==
  size_t hash() const {
    auto seed = static_cast<size_t>(start_branch);
    hashCombine(seed, static_cast<size_t>(stop_branch));
    hashCombine(seed, static_cast<size_t>(step_branch));
    hashCombine(seed, static_cast<size_t>(is_empty));
    hashCombine(seed, static_cast<size_t>(iter_type));
    return seed;
  }
};

//! A set of transformations for a symbolic fusion with concrete sizes
//! of the fusion inputs
class TORCH_CUDA_CU_API DynamicTransformConcretizationInfo {
Expand All @@ -115,9 +176,7 @@ class TORCH_CUDA_CU_API DynamicTransformConcretizationInfo {
// evaluator when any one of the IDs has a known value
expr_eval->propagateBoundValuesThroughExactMaps(initial_info->fusion());

analyzeReshapes(expr_eval);

analyzeResizes(expr_eval);
analyze(expr_eval);
}

//! Return a vector of pairs holding the index of each reshaped TensorView in
Expand All @@ -136,6 +195,15 @@ class TORCH_CUDA_CU_API DynamicTransformConcretizationInfo {
return resize_itertypes_;
}

//! Return a vector of pairs holding the index of each sliced TensorView along
//! with a vector of descriptors indicating how each axis should be
//! concretized.
//! NOTE(review): the index presumably refers to a position in the vector
//! returned by initialInfo()->getDynamicExprOutputs();
//! getDynamicSlicedTensorViews(), referenced previously, no longer exists —
//! confirm against the code that fills slice_descriptors_.
const std::vector<std::pair<size_t, std::vector<Concrete1DSliceDescriptor>>>&
getSliceDescriptors() const {
  return slice_descriptors_;
}

//! Comparison operator for the purposes of determining cache hits. This does
//! not guarantee equality of all members. Instead, it returns equal if the
//! resulting concretizations would be structurally equivalent. Note that
Expand All @@ -148,13 +216,21 @@ class TORCH_CUDA_CU_API DynamicTransformConcretizationInfo {
}

//! Given an ExpressionEvaluator which already has input scalars bound to it,
//! determine the decomposition of each dynamic reshape operation to use
//! analyze all dynamic ops in topological order.
void analyze(ExpressionEvaluator* expr_eval);

//! Given an ExpressionEvaluator which already has input scalars bound to it,
//! determine the decomposition of a dynamic reshape operation to use
//! during concretization.
void analyzeReshapes(ExpressionEvaluator* expr_eval);
void analyzeReshape(ExpressionEvaluator* expr_eval, size_t val_index);

//! Given an ExpressionEvaluator which already has input scalars bound to it,
//! determine the branches of expressions in a dynamic slice op.
void analyzeSlice(ExpressionEvaluator* expr_eval, size_t val_index);

//! Given an ExpressionEvaluator which already has input scalars bound to it,
//! determine the concrete IterType of each resized IterDomain.
void analyzeResizes(ExpressionEvaluator* expr_eval);
//! determine the concrete IterType of a resized IterDomain.
void analyzeResize(ExpressionEvaluator* expr_eval, size_t val_index);

const DynamicTransformInitialInfo* initialInfo() const {
return initial_info_;
Expand Down Expand Up @@ -189,6 +265,12 @@ class TORCH_CUDA_CU_API DynamicTransformConcretizationInfo {
//! vector returned by initial_info_->getDynamicResizedIterDomains() along
//! with its concretized IterType
std::vector<std::pair<size_t, IterType>> resize_itertypes_;

//! Holds the index of the sliced TensorView (output of the SliceOp) in the
//! vector returned by initial_info_->getDynamicExprOutputs() along with a
//! vector of per-axis descriptors of how it should be concretized.
std::vector<std::pair<size_t, std::vector<Concrete1DSliceDescriptor>>>
slice_descriptors_;
};

class TORCH_CUDA_CU_API DynamicTransform {
Expand Down
18 changes: 18 additions & 0 deletions csrc/evaluator_common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,12 @@ NaiveValueMachine::NaiveValueMachine(PrecomputedValues& precomputed_values)
makeUnaryOp(uop);
} else if (auto bop = dynamic_cast<BinaryOp*>(def)) {
makeBinaryOp(bop);
} else if (auto setop = dynamic_cast<LoadStoreOp*>(def)) {
TORCH_INTERNAL_ASSERT(
setop->opType() == LoadStoreOpType::Set,
"NaiveValueMachine: unsupported LoadStoreOpType: ",
setop->opType());
makeSetOp(setop);
} else {
TORCH_INTERNAL_ASSERT(false, "Unsupported expr");
}
Expand Down Expand Up @@ -448,6 +454,18 @@ void NaiveValueMachine::makeBinaryOp(BinaryOp* bop) {
dest_[index] = out;
}

//! Convert a Set LoadStoreOp into a SET_OP instruction that forwards the
//! value in the input slot to the output slot.
void NaiveValueMachine::makeSetOp(LoadStoreOp* lsop) {
  // Both operands must already have evaluator slots assigned.
  const int src_index = lsop->inputs()[0]->evaluatorIndex();
  const int dest_index = lsop->outputs()[0]->evaluatorIndex();
  TORCH_INTERNAL_ASSERT(
      src_index >= 0, "Integer Machine: unknown input: ", lsop);
  TORCH_INTERNAL_ASSERT(
      dest_index >= 0, "Integer Machine: unknown out: ", lsop);

  // Append a new instruction and fill in its type and operand slots.
  const int inst = makeInstructionEntry();
  inst_type_[inst] = InstructionType::SET_OP;
  src0_[inst] = src_index;
  dest_[inst] = dest_index;
}

int NaiveValueMachine::makeInstructionEntry() {
int index = num_of_instructions_++;
inst_type_.emplace_back(InstructionType::UNARY_OP);
Expand Down
4 changes: 4 additions & 0 deletions csrc/evaluator_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ class NaiveValueMachine {
//! Convert an binary IR expr to an instruction
void makeBinaryOp(BinaryOp* bop);

//! Convert a LoadStoreOp expr to an instruction. This assumes lsop->opType()
//! is equal to LoadStoreOpType::Set.
void makeSetOp(LoadStoreOp* lsop);

//! Create an empty instruction with all default values
//! and place it at the end of the instruction buffer.
int makeInstructionEntry();
Expand Down
9 changes: 8 additions & 1 deletion csrc/ir/nodes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2113,7 +2113,14 @@ std::string LoadStoreOp::toString(int indent_size) const {
}

std::string LoadStoreOp::toInlineString(int indent_size) const {
  // Only scalar Set ops have an inline printed form; reject everything else
  // up front so the remainder of the function is a single straight-line path.
  TORCH_CHECK(
      opType() == LoadStoreOpType::Set,
      "Non-'Set' LoadStoreOp cannot be printed inline");
  TORCH_CHECK(
      !in()->isA<TensorView>(), "Cannot print TensorView set() inline");
  std::stringstream ss;
  indent(ss, indent_size) << "set(" << in()->toInlineString() << ")";
  return ss.str();
}

bool LoadStoreOp::hasTranspose() const {
Expand Down
15 changes: 14 additions & 1 deletion csrc/ir/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -654,7 +654,7 @@ struct ReplaceValInIndexVal : public OptInDispatch {
// Recursively traverse its defining expr
auto def = val->definition();
if (def != nullptr) {
if (def->isOneOf<UnaryOp, BinaryOp, TernaryOp>()) {
if (def->isOneOf<UnaryOp, BinaryOp, TernaryOp, LoadStoreOp>()) {
handle(val->definition());
} else {
TORCH_INTERNAL_ASSERT(false, "Unexpected definition: ", def->toString())
Expand All @@ -665,6 +665,19 @@ struct ReplaceValInIndexVal : public OptInDispatch {
}
}

// Clone expression after recursively replacing inputs
void handle(LoadStoreOp* lsop) override {
  handle(lsop->in());
  auto inp = last_visited_val_;
  // Create a replacement output of the same value type as the original.
  // The assertion previously allowed the output to be Int or Bool but an
  // Int was unconditionally created, which would silently change the dtype
  // of a Bool-valued Set.
  Val* out = nullptr;
  if (lsop->out()->isA<Int>()) {
    out = IrBuilder::create<Int>(c10::nullopt);
  } else if (lsop->out()->isA<Bool>()) {
    out = IrBuilder::create<Bool>(c10::nullopt);
  } else {
    TORCH_INTERNAL_ASSERT(
        false, "Unknown output type for expr ", lsop->toInlineString());
  }
  IrBuilder::create<LoadStoreOp>(lsop->opType(), out, inp);
  last_visited_val_ = out;
}

// Clone expression after recursively replacing inputs
void handle(UnaryOp* uop) override {
handle(uop->in());
Expand Down
4 changes: 4 additions & 0 deletions csrc/kernel_cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,10 @@ InputsIdLookup::IdLookupReturn InputsIdLookup::lookupId(
FusionExecutorCache::FusionExecutorCache(std::unique_ptr<Fusion> fusion)
: fusion_(std::move(fusion)) {}

//! Return the complete (concretized) Fusion from the most recently used
//! kernel runtime. Must only be called after the fusion has been run at
//! least once.
Fusion* FusionExecutorCache::getMostRecentConcretizedFusion() const {
  // most_recent_runtime_ is null until a runtime has been created;
  // dereferencing it unconditionally would crash instead of giving a
  // diagnosable error.
  TORCH_INTERNAL_ASSERT(
      most_recent_runtime_ != nullptr,
      "getMostRecentConcretizedFusion() called before any kernel runtime exists");
  return most_recent_runtime_->fusionSegments()->completeFusion();
}

KernelArgumentHolder FusionExecutorCache::prepareInputs(
const at::ArrayRef<c10::IValue>& inputs,
std::optional<int8_t> selected_device) {
Expand Down
2 changes: 2 additions & 0 deletions csrc/kernel_cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -507,6 +507,8 @@ class TORCH_CUDA_CU_API FusionExecutorCache {
fusion_->printMath();
}

Fusion* getMostRecentConcretizedFusion() const;

//! Return the kernel runtime used by the most recent execution.
//! NOTE(review): most_recent_runtime_ appears to be unset (null) before the
//! first execution — callers should confirm and handle that case.
FusionKernelRuntime* getMostRecentKernelRuntime() const {
  return most_recent_runtime_;
}
Expand Down
69 changes: 58 additions & 11 deletions csrc/ops/alias.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -639,10 +639,14 @@ TensorView* cat(const std::vector<TensorView*>& inputs, int64_t cat_dim) {
return out;
}

// Currently there's no error check about the actual values of the
// Slice parameters. For example, the start parameter of a range of a
// domain is assumed to be >= 0 and < the extent of the domain.
TensorView* slice(TensorView* inp, const std::vector<Slice>& ranges) {
// If skip_symbolic is true, then the start and stop parameters of a range of a
// domain are assumed to be >= 0 and < the extent of the domain. Otherwise,
// non-constant inputs will lead to Symbolic IterDomains in the output, which
// must later be concretized.
TensorView* slice(
TensorView* inp,
const std::vector<Slice>& ranges,
bool skip_symbolic) {
const auto inp_dom = TensorDomain::noReductions(inp->getMaybeRFactorDomain());
const int ndims = static_cast<int>(inp_dom.size());

Expand All @@ -666,6 +670,19 @@ TensorView* slice(TensorView* inp, const std::vector<Slice>& ranges) {
return range;
};

// Adjust an integer value relative to a given extent. This is
// min(max(where(a < 0, extent + a, a), 0), extent)
auto adjust_start_stop = [](int64_t& a, int64_t extent) {
if (a < 0) {
a += extent;
}
if (a < 0) {
a = 0;
} else if (a > extent) {
a = extent;
}
};

for (auto& range : ranges) {
// Step not supported yet
TORCH_CHECK(
Expand All @@ -681,23 +698,53 @@ TensorView* slice(TensorView* inp, const std::vector<Slice>& ranges) {
bool needs_real_slicing = false;
for (const auto idx : c10::irange(ndims)) {
auto inp_root_id = inp_dom[idx];
auto range = normalize_slice_range(ranges.at(idx), inp_root_id->extent());
auto inp_extent = inp_root_id->getMaybeExpandedExtent();
auto range = normalize_slice_range(ranges.at(idx), inp_extent);
normalized_ranges.at(idx) = range;
IterDomain* out_root_id = nullptr;
IterDomain* out_rf_id = nullptr;
if (range.start->isZeroInt() && range.stop->sameAs(inp_root_id->extent()) &&
if (range.start->isZeroInt() && range.stop->sameAs(inp_extent) &&
range.step->isOneInt()) {
// This dim doesn't need slicing
out_root_id = inp_root_id->cloneWithoutRFactor();
out_rf_id = out_root_id;
} else {
out_root_id =
IterDomainBuilder(inp_root_id).is_rfactor_domain(true).build();
out_rf_id = IterDomain::resize(
out_root_id,
SimplifyingIrBuilder::negExpr(range.start),
sub(range.stop, inp_root_id->extent()),
true);
// The start, stop, and extent of the output will all require complicated
// expressions which will be simplified at concretization. Here we set
// the output to Symbolic unless all required scalars are constant.
if (range.start->isConstInt() && range.stop->isConstInt() &&
inp_extent->isConstInt()) {
auto start = range.start->evaluateInt();
auto stop = range.stop->evaluateInt();
auto step = range.step->evaluateInt();
TORCH_INTERNAL_ASSERT(step != 0, "Slice step must be non-zero");
TORCH_INTERNAL_ASSERT(
step == 1, "Slicing with step != 1 is not currently supported");
auto inp_extent_val = inp_extent->evaluateInt();
adjust_start_stop(start, inp_extent_val);
adjust_start_stop(stop, inp_extent_val);
out_rf_id = IterDomain::resize(
out_root_id,
SimplifyingIrBuilder::negExpr(IrBuilder::create<Int>(start)),
sub(IrBuilder::create<Int>(stop), inp_extent),
true);
} else if (skip_symbolic) {
out_rf_id = IterDomain::resize(
out_root_id,
SimplifyingIrBuilder::negExpr(range.start),
sub(range.stop, inp_extent),
true,
IterType::Iteration);
} else {
out_rf_id = IterDomainBuilder(
FusionGuard::getCurFusion()->zeroVal(),
IrBuilder::create<Int>())
.is_rfactor_domain(true)
.iter_type(IterType::Symbolic)
.build();
}
needs_real_slicing = true;
}
root_ids.at(idx) = out_root_id;
Expand Down
Loading