
Return default ReductionParams when input has zero elements #269

Closed
jacobhinkle wants to merge 22 commits into NVIDIA:main from jacobhinkle:reduce_zero_elt

Conversation

@jacobhinkle
Collaborator

Fixes #264

This just adds a short-circuit to {inner,outer}ReductionHeuristic that determines whether the input tensor to the reduction has any size-0 dimensions. If so, we return the default ReductionParams early so that we can safely assume numel>0 for the remainder of the heuristic.

i.e. if total_reduction_numel OR total_iteration_numel is zero.

In these cases we just return the default ReductionParams, which will
launch a single block.
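The short-circuit can be sketched in plain Python; the function name and the ReductionParams stand-in below are illustrative only, not the actual nvfuser API:

```python
import math

def reduction_heuristic(reduction_extents, iteration_extents):
    """Hedged sketch of the early return described above."""
    total_reduction_numel = math.prod(reduction_extents)
    total_iteration_numel = math.prod(iteration_extents)
    # Any size-0 dimension makes one of these products zero, so the
    # remainder of the heuristic can safely assume numel > 0.
    if total_reduction_numel * total_iteration_numel == 0:
        # Default params stand-in: launch a single block.
        return {"grid": (1, 1, 1), "block": (1, 1, 1)}
    return None  # placeholder for the rest of the real heuristic

assert reduction_heuristic([3, 0, 2], [5]) == {"grid": (1, 1, 1), "block": (1, 1, 1)}
```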
@jacobhinkle requested a review from naoyam May 2, 2023 18:23
@naoyam
Collaborator

naoyam commented May 2, 2023

I think that if size-0 dimensions appear as non-reduction domains, the no-op scheduler should be used, so the only case that could fall in this condition is the non-reduction domains are all non-zero and there's a size-0 reduction dimension. Or, total_reduction_numel is zero, while total_iteration_numel is non-zero. Am I correct?

jacobhinkle and others added 2 commits May 3, 2023 12:55
@jacobhinkle
Collaborator Author

!build

@jacobhinkle
Collaborator Author

I think that if size-0 dimensions appear as non-reduction domains, the no-op scheduler should be used, so the only case that could fall in this condition is the non-reduction domains are all non-zero and there's a size-0 reduction dimension. Or, total_reduction_numel is zero, while total_iteration_numel is non-zero. Am I correct?

You're right, the NoOp scheduler will pick this up if the only reductions have zero concrete elements. The reduction scheduler could still run if there are multiple reductions, or, as you say, when all the non-reduction extents are non-zero. I am testing both cases here, and I checked that we actually return the starting value when we have all non-zero concrete dimensions. This is how torch's sum works, but note that the various reductions in torch each behave a bit differently from one another:

[nav] In [1]: import torch

[ins] In [2]: x = torch.randn(3, 0, 2)

[ins] In [3]: x.sum(1)
Out[3]: 
tensor([[0., 0.],
        [0., 0.],
        [0., 0.]])

[ins] In [4]: x.prod(1)
Out[4]: 
tensor([[1., 1.],
        [1., 1.],
        [1., 1.]])

[ins] In [5]: x.amax(1)
---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
Cell In[5], line 1
----> 1 x.amax(1)

IndexError: amax(): Expected reduction dim 1 to have non-zero size.

[ins] In [6]: x.sum([0,1])
Out[6]: tensor([0., 0.])

[ins] In [7]: x.prod([0,1])
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[7], line 1
----> 1 x.prod([0,1])

TypeError: prod() received an invalid combination of arguments - got (list), but expected one of:
 * (*, torch.dtype dtype)
 * (int dim, bool keepdim, *, torch.dtype dtype)
 * (name dim, bool keepdim, *, torch.dtype dtype)


[ins] In [8]: x.amax([0, 2])
Out[8]: tensor([])

Some accept lists of dims, others (prod) accept only single dimensions. Some don't require numel>0 in those dimensions (sum, prod) and others do (amax). Anyway, I believe NVFuser correctly returns the starting value if the number of reduced elements is zero, and I'm somewhat confident now that the reduction scheduler won't balk at any of the scenarios we've laid out so far.
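The "starting value" here is just the identity element of the reduction fold. A minimal stdlib sketch (not nvfuser or torch code) of the same semantics:

```python
from functools import reduce

empty = []
# sum and prod have identity elements, so reducing zero elements
# yields the initial value, matching torch.sum / torch.prod above.
assert reduce(lambda a, b: a + b, empty, 0) == 0  # sum over nothing -> 0
assert reduce(lambda a, b: a * b, empty, 1) == 1  # prod over nothing -> 1

# max has no identity element, which is why torch.amax raises instead:
try:
    reduce(max, empty)  # no initializer available
    raised = False
except TypeError:
    raised = True
assert raised
```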


const int64_t n_elems = total_reduction_numel * total_iteration_numel;

if (n_elems == 0) {
Collaborator

Instead of having this condition in both innerReductionHeuristic and outerReductionHeuristic, can we have just one in reductionHeuristic?

Collaborator

Does this still happen? Don't the pointwise and noop schedulers take care of all these conditions?

Collaborator

I have the same question.

// Replace each reduction with a call to full(). Assumes we have already
// verified that this is safe to do.
for (auto rop : ir_utils::getReductionOps(fusion)) {
ir_utils::replaceReductionWithFull(rop);
Collaborator Author

@jacobhinkle commented May 4, 2023

@naoyam , as we discussed this replaces reductions with full (after verifying that this is safe during canScheduleRuntime). The problem is that this replacement breaks the path from outputs to inputs. In a simple linear kernel such as in the test, schedulePointwise fails to find a reference tensor. In this kind of case, we actually have an unused input after the transformation, but I don't think there's a convenient way to eliminate that upstream branch from the replaced expression. Perhaps this could be done by introducing a new OptInMutator that would remove upstream branches (if this doesn't already exist), since this would clear the uses() for any inputs that would have only dead branches, so the reference tensor would be cleared.

Collaborator

Sounds strange to me that an unused input would make the reference detection fail. Unused inputs would happen without this translation, and I think they would be just ignored, no?

Collaborator Author

It becomes unused after we break the fusion at the reduction. However, the uses() of the input is still non-empty because that upstream branch is not cleared. Same happens if I use ir_utils::replaceValue.

Collaborator Author

OptOutMutator only does the replacements/removals we register, so we need to do another pass to remove the upstream branch.

Collaborator

This is no longer an issue, right?

Collaborator

Should be.

out_shape[i] = out_root[i]->extent();
}
auto new_out = full(out_shape, rop->init(), rop->out()->dtype());
ir_utils::replaceValInExpr(expr, rop->out(), new_out);
Collaborator

We don't want to use expr anymore as it's a reduction op.

Instead, try this:

auto rop_out = rop->out();
auto out_uses = rop_out->uses();
for (auto out_use: out_uses) {
  ir_utils::replaceValInExpr(out_use, rop_out, new_out);
}
// At this point, rop_out is no longer used by any expr. 
rop->container()->removeVal(rop);

I believe this should remove the original reduction expr and its output from the fusion. It should also trigger updating the usage map, so the input should become unused.

Let me know if this works.

Collaborator Author

Yeah, you're right. I pushed a recursive version that removes the branch. After that, the test passes. I am now extending this to Welford, in which case instead of full we have, for example, inAvg / inN, which may be nan, matching PyTorch.

Collaborator

Just want to understand, is this recursive removal necessary?

Collaborator Author

I think it's necessary, since there could be ops before the reduction. For example, the test currently reads:

  auto tv0 = makeSymbolicTensor(2);
  fusion->addInput(tv0);
  auto tv2 = sum(tv0, {1});
  fusion->addOutput(tv2);

Removing the sum makes tv0->uses() empty, so it doesn't complain here: https://github.com/NVIDIA/Fuser/blob/main/csrc/scheduler/pointwise_utils.cpp#L109

If we have another op before the reduction:

  auto tv0 = makeSymbolicTensor(2);
  fusion->addInput(tv0);
  auto tv1 = neg(tv0);
  auto tv2 = sum(tv1, {1});
  fusion->addOutput(tv2);

and we just remove the sum, tv0 would still have a use, so it would block tv2 being a valid reference tensor.

Collaborator

It shouldn't be. If the sum is removed, the fusion would look like:

  auto tv0 = makeSymbolicTensor(2);
  fusion->addInput(tv0);
  auto tv1 = neg(tv0);
  auto tv3 = full(...);
  fusion->addOutput(tv3);

Then, the use map should be updated and the neg should no longer appear as a use of tv0 since it's not connected with the output. I think its allocated objects are still kept in the fusion container, so removing them is helpful, but shouldn't be necessary.

Collaborator Author

You're right that resetTvUses() updates uses for us, as you say. I removed the prune function. I am now trying to fix another issue that came up when I added more ops: finding the reference tensor looks at all input IterDomains to see that they map to the reference, but since we accept reductions now, those reduced IterDomains in the input are not mapped to the reference. So I'm updating DomainMap to make exceptions for reduced domains in the input.

jacobhinkle added 11 commits May 4, 2023 15:03
pruneProducerBranches starts at given values and moves upstream removing
Exprs and Vals with no outside outputs/uses.
This is not necessary, since resetTvUses will automatically be run when
needed. Still, the updated test including unary ops before and after the
replaced reduction is still failing.
NoOp cases are currently failing.
Tests still fail. It's difficult to satisfy NoOp by replacement with
full, unless we also concretize zero-size extents in such a way that
size-zero extents are constant. That would mean moving the check back to
compile time for NoOp, which would be simpler. It would also mean all
Fusions with reductions would be dynamic, as originally discussed
between myself and @naoyam. This would mean adding a concretization step
for each new set of input sizes, but that can be made faster as well,
e.g. NVIDIA#244.

namespace nvfuser {

class SchedulerRuntimeInfo;
Collaborator

ir_utils requiring SchedulerRuntimeInfo seems a bit odd to me.

Collaborator Author

I currently only use it for the ExpressionEvaluator, so that we can tell which extents are actually zero.

TORCH_CUDA_CU_API bool isReductionTvOp(const Expr*);

//! Returns if Expr is a reduction over at least one size-0 dimension
TORCH_CUDA_CU_API bool isReductionOverSizeZero(
Collaborator

Maybe this should be in scheduler/reduction_utils.h?

Collaborator

Agreed, schedulers commonly use runtime info and we should avoid that leaking out of the schedulers dir.

All of these functions seem to be used around scheduling/segmentation. I don't think it makes sense to have them hang out here.

Collaborator Author

If we go with the approach in this PR, I will move them. Thanks for the context: I will try and keep most runtime info corralled in scheduler/.


//! Given a set of Vals, recursively remove them until we find an input or an
//! expression with outputs that are used outside the provided set.
TORCH_CUDA_CU_API void pruneProducerBranches(
Collaborator

Do you still have this function defined?

}
auto expr_eval = runtime_info.expressionEvaluator();
for (auto output : ir_utils::filterByType<TensorView>(expr->outputs())) {
for (auto id : output->getRootDomain()) {
Collaborator

I think this should be getMaybeRfactorDomain

}
auto expr_eval = runtime_info.expressionEvaluator();
for (auto output : ir_utils::filterByType<TensorView>(expr->outputs())) {
for (auto id : output->getRootDomain()) {
Collaborator

Same here


const int64_t n_elems = total_reduction_numel * total_iteration_numel;

if (n_elems == 0) {
Collaborator

Does this still happen? Don't the pointwise and noop schedulers take care of all these conditions?

return !reason.empty();
}

// Reject if there is a non-trivial reduction
Collaborator

Why is this here?

Collaborator Author

These checks get run by FusionKernelRuntime when we have previously accepted a Fusion and need to check it against new inputs. If we first run with size-zero reductions but then pass inputs that make those reductions non-trivial, we need Pointwise and NoOp to reject. Same goes for the other direction. All these checks are avoided if we handle reductions at concretization instead since we can then assume the reduction is nontrivial in both concrete and reduction domains.

Collaborator

Ah, sorry, I thought this was added to the transpose scheduler, since I saw TransposeScheduler at around line 1735.

Collaborator

If we made size==0 special like size==1 then we wouldn't need this check.

Collaborator Author

That's an interesting idea. I have been picturing size==0 as Iteration, but it does have this slightly different implication as its presence in a TensorView implies that TV's definition is a noop.

Collaborator Author

When a TV has a SizeZero domain, it is indicating that its definition() need not be processed. We can replace the TV with a constant empty tensor, essentially. If all its uses() lead to outputs with SizeZero IDs, then the TV and all the outputs don't need to be processed by the Fusion at all; they can be DCE'd. Reductions preserve SizeZero segments though since if all the SizeZero IDs appear in the reduction axes of the op, the output is equivalent to full.

Also, see #340 which makes the case for using Broadcast to indicate more than extent == 1, namely that the broadcast also is resolved at some point. Likewise, iter_type == SizeZero has a relationship to reductions as stated above, where reducing over SizeZero "resolves" the size-0 property by implying we can replace the reduction with the reduction's initial value (or nan for Welford).

// Purely symbolic extents _might_ be zero depending on inputs.
// These will need to be checked at runtime. So we should only
// reject if the extent is known to be non-zero at compile time.
return !id->extent()->isConstInt() || id->extent()->isZeroInt();
Collaborator

Shouldn't this check be also done at the run time? The compile-time check only makes sure there may be a size-0 non-reduction domain, but it's not guaranteed.

Collaborator Author

Yes it needs a runtime check too.

Collaborator

I wonder if we should treat zero size as a special type like we do with one size so if a size changes in/out of zero we always recompile. cc @jjsjann123

// Reject if there are any zero-size reductions, since these would prefer
// pointwise, or if there are reductions with zero-size concrete
// dimensions, since these would prefer NoOp
if (ir_utils::isReductionOverSizeZero(rop, runtime_info) ||
Collaborator

What would happen if there are two reduction domains and only one of them is size-0? It doesn't seem like a full op as there's still a non-size-0 reduction domain, but isReductionOverSizeZero would return true.

Collaborator Author

My understanding may be flawed, but in that case there are still zero elements being reduced, aren't there? For example, we could merge all the reduction IterDomains and would wind up with an extent-zero reduction domain.
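That merge argument is simple arithmetic: merging IterDomains multiplies their extents, so a single size-0 reduction domain zeroes the merged reduction extent. The extents below are hypothetical:

```python
import math

# Hypothetical reduction extents; the middle IterDomain is size-0.
reduction_extents = [4, 0, 7]
# Merging the reduction IterDomains multiplies their extents, so
# one zero extent means zero elements are reduced overall.
merged_extent = math.prod(reduction_extents)
assert merged_extent == 0
```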

Collaborator

I agree with this interpretation.

}
}

auto reduction_ops = ir_utils::getReductionOps(fusion);
Collaborator

Where's the corresponding check of the eligibility of using the pointwise scheduler with reductions?

Collaborator

WDYM?

eraseIfInputMappedThroughRFactorDomain(
in_concrete_ids, tv->getMaybeRFactorDomain());

// Erase input concrete IDs mapped to any reduction domains
Collaborator

I think we can just ignore any reduction domains as they should not show up in the fusion after reduction-to-full translations.

Collaborator Author

We could, but this will check the root IterDomains of each input TV to check that they map to something in the reference tensor (output TV). If the replaced reduction was not the first or last op in the Fusion, then the reduction domain will be hidden so we have to check if it is mapped.

Collaborator

Couldn't you simply ignore any input tensor view that doesn't have any uses()?

Collaborator Author

This function needs to check that all input IDs are represented by the reference tensor. Any input with empty uses() would violate that since its IDs would not map to any others. You're right that this is safe, and is already in place at the one call-site:

for (auto input_tv : ir_utils::filterByType<TensorView>(fusion_->inputs())) {
if (input_tv->uses().empty()) {
continue;
}
Note that more generally, we could use InputsOf(fusion->outputs()) so that empty uses propagate back to give us more ignored inputs.

The check here is actually skipping any input IDs that map to internal reductions (not represented in outputs). That seemed like a good idea at the time, but now I realize it's a bad idea, since we do still need those IDs represented in the reference tensor.

Collaborator

@csarofeen left a comment

Could you please summarize the expected behavior of size==0 scenarios you're supporting with this PR?

This feels like a bit strange of a WAR, but I don't see any reason it wouldn't work.

}

// Same as previous test but with Welford instead
TEST_F(NVFuserTest, FusionWelfordRFactor_CUDA) {
Collaborator

Is this supposed to be here?

TORCH_CUDA_CU_API bool isReductionTvOp(const Expr*);

//! Returns if Expr is a reduction over at least one size-0 dimension
TORCH_CUDA_CU_API bool isReductionOverSizeZero(
Collaborator

Agreed, schedulers commonly use runtime info and we should avoid that leaking out of the schedulers dir.

All of these functions seem to be used around scheduling/segmentation. I don't think it makes sense to have them hang out here.

auto rop = expr->as<ReductionOp>();
auto old_out = rop->out()->as<TensorView>();
auto new_out = full_like(old_out, rop->init());
ir_utils::replaceValue(fusion, {{old_out, new_out}});
Collaborator

I think you're saying the only difference between using ir_utils::replaceValue and your proposed loop is that the former would change rop to have an incorrect output, and the latter wouldn't touch rop.

This is because even though rop shouldn't be used since it's disconnected from the outputs, it's still not great to have an illegally defined reduction operation hanging out.

}
} else if (expr->isA<WelfordOp>()) {
// Fill Welford avg/var with nan, N with 0
auto wop = expr->as<WelfordOp>();
Collaborator

Seems like this branch could easily be merged with the one above just looking at if the op is a reduction and doing the same thing for all output tensorviews.

Collaborator Author

Yes I think it could. I originally had a little more logic in here that was wrong, so it was removed. Each output of a Welford is equal to the two reductions each divided by N. For size-0 reduction domain, this is 0/0. So we just need to add the division by 0 in the case of Welford, but otherwise it's the same logic as for any other reduction.
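A hedged sketch (not nvfuser code) of that Welford finalization, showing why a size-0 reduction yields nan for avg/var while N stays 0:

```python
import math

def welford_finalize(sum_x, sum_x2, n):
    """Illustrative finalization: avg and var are sums divided by N,
    so a size-0 reduction (N == 0) is the 0/0 case, giving nan."""
    if n == 0:
        return float("nan"), float("nan"), 0
    avg = sum_x / n
    var = sum_x2 / n - avg * avg  # population variance from raw sums
    return avg, var, n

avg, var, n = welford_finalize(0.0, 0.0, 0)
assert math.isnan(avg) and math.isnan(var) and n == 0
```

This matches PyTorch, where mean/var over an empty dimension produce nan.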

return false;
}

void replaceReductionWithFull(Expr* expr) {
Collaborator

Shouldn't you have special handling for prod formulation here? I thought that currently errors in PyTorch.

Collaborator Author

Actually prod is covered the same as sum: by replacing the op with full(init_val). PyTorch is consistent with this:

[ins] In [2]: torch.prod(torch.randn([3,0]))
Out[2]: tensor(1.)

std::cout << "\nRemoving reductions\n" << std::endl;
for (auto rop : fusion->exprs()) {
if (rop->isA<ReductionOp>() || rop->isA<WelfordOp>()) {
ir_utils::replaceReductionWithFull(rop);
Collaborator

When would this get hit? Why do we need to remove the reduction ops in a no schedule situation? Is this because of some runtime check?

// Reject if there are any zero-size reductions, since these would prefer
// pointwise, or if there are reductions with zero-size concrete
// dimensions, since these would prefer NoOp
if (ir_utils::isReductionOverSizeZero(rop, runtime_info) ||
Collaborator

I agree with this interpretation.

// dimensions, since these would prefer NoOp
if (ir_utils::isReductionOverSizeZero(rop, runtime_info) ||
ir_utils::reductionHasSizeZeroConcreteDimension(rop, runtime_info)) {
return false;
Collaborator

If we made size==0 special like size==1 then we wouldn't need this check.

return !reason.empty();
}

// Reject if there is a non-trivial reduction
Collaborator

If we made size==0 special like size==1 then we wouldn't need this check.

}
}

auto reduction_ops = ir_utils::getReductionOps(fusion);
Collaborator

WDYM?

@jacobhinkle
Collaborator Author

@csarofeen before I address your points point-by-point, note that in some side conversations, @naoyam and I have discussed an alternative to this PR wherein we mark most reduction outputs as Symbolic, then perform the replacement with full at concretization. That means none of this size-zero reduction logic needs to be present in the schedulers, since we can assume no reductions have size-zero inputs.

Going that route means a large proportion of our Fusions that were otherwise static may become dynamic, but as we see in this current PR, static Fusions can also have some complicated rewriting logic during scheduling. Also note that if we can prove the extents are non-zero, we could potentially avoid dynamic Fusions even with reductions (see #340). So I am leaning toward abandoning this PR and creating a new one for dynamic reductions, once #258 and #244 are merged.

@jacobhinkle
Collaborator Author

Could you please summarize the expected behavior of size==0 scenarios you're supporting with this PR?

I mentioned it above in some of the responses, but basically I think size == 0 has two key implications:

  • TVs with size-0 domains do not need allocations and do not need to compute their definitions
  • Reductions over size-0 domains are equivalent to a constant (the initial value of the reduction fold).

This feels like a bit strange of a WAR, but I don't see any reason it wouldn't work.

It definitely doesn't feel clean. Maybe it's because it's the shiny new toy, but I'm inclined to use the dynamic fusion machinery for this instead.

@jjsjann123
Collaborator

jjsjann123 commented May 16, 2023

@naoyam and I have discussed an alternative to this PR wherein we mark most reduction outputs as Symbolic, then perform the replacement with full at concretization. That means none of this size-zero reduction logic needs to be present in the schedulers, since we can assume no reductions have size-zero inputs.

This sounds like a much better approach to me than trying to work around this corner case in the scheduler. With concretization specializing on size-0 dimensions, it should be easier to rewrite the reduction ops before handing them to the schedulers.

I wonder if we should treat zero size as a special type like we do with one size so if a size changes in/out of zero we always recompile. cc @jjsjann123

I think the specialization here can also be worked into concretization, and is likely better that way instead of polluting our existing API. I'm expecting size-0 to not change frequently across runs, so hopefully this won't trigger extended overhead and recompilation.
In the long run, the holy grail is to have concretization handling both shape specialization as well as contiguity/stride_order specialization, so we can hide those runtime info from user API when defining computation.

@jacobhinkle
Collaborator Author

Closing in favor of #449

jacobhinkle added a commit that referenced this pull request Jul 14, 2023
A number of issues have come up when trying to process empty tensors
(i.e. ones with at least one non-reduction axis with extent of zero)
during scheduling and lowering. See: #264 #369 #269. Additionally, we
now assume extents are positive (#440). Along with #543, this PR makes
that a reality by removing all intermediate empty tensors.

This PR:
- Marks a `Fusion` as dynamic if dynamic reshapes/resizes exist or if
_any_ alive `TensorView` has a static size-zero extent or a dynamic
extent, since it might be empty. **This means only Fusions with
nothing but concrete non-zero sizes are static now.** That is, even if
all static shapes are provided, it will be marked as a dynamic Fusion
and those `TensorView`s will be modified during concretization.
- Adds a pass done during `getConcretizationInfo()` that collects a
vector of empty tensors which are not fusion inputs. It does not
traverse their definitions, since there is nothing to compute for an
empty tensor.
- During concretization, sets the size-0 extents of identified empty
tensors to constant 0.

When encountering a new set of input sizes/scalars, we evaluate a
minimal set of `Val`s (those that appear in dynamic extents), and only
proceed with removing branches if any of these are zero. So there is a
rather quick path to re-using concretizations in the common case where
none of the extents are zero.

Even after #543, this PR does not guarantee that all tensors present in
the Fusion during scheduling have non-zero extent. It does guarantee
that any remaining empty tensors are either fusion inputs or outputs,
and that empty tensors will have constant 0 extents in any empty
dimensions. Stripping empty inputs and outputs from the Fusion could
potentially be done at segmentation, but should only be done if it does
not result in additional kernels being launched; that is left for
another PR (see #448).

Fixes #365 and fixes #264. This replaces PRs #369 and #269.

---------

Co-authored-by: Naoya Maruyama <naoyam@users.noreply.github.com>
Successfully merging this pull request may close these issues.

Floating-point exception scheduling reduction of zero-element tensor