clean normalization_inner#923

Merged
liqiangxl merged 8 commits into main from
llu/clean_inner_norm
Sep 23, 2023

Collaborator

@liqiangxl liqiangxl commented Sep 21, 2023

(1) Transformed private static functions in the InnerPersistentKernelScheduler class into utility functions within an anonymous namespace.
(2) Removed checks/calculations that are unrelated to the inner persistent scheduler.
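As a rough illustration of change (1), the sketch below moves a helper out of a class and into an anonymous namespace. Only the names `checkReductionPattern` and `InnerPersistentKernelScheduler` come from this PR. The `TensorView` struct is a hypothetical stand-in, the simplified `canScheduleCompileTime` signature is an assumption, and the axis comparison is a placeholder for the root-domain-map check in the real function.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for nvFuser's TensorView (assumption for this sketch):
// one flag per axis, true if that axis is reduced.
struct TensorView {
  std::vector<bool> is_reduction_axis;
};

// Before (sketch): the helper was a private static member and had to be
// declared inside the class definition.
//
//   class InnerPersistentKernelScheduler {
//     static bool checkReductionPattern(
//         const std::vector<TensorView*>& reduction_tvs);
//   };

namespace {

// After (sketch): a free function with internal linkage, visible only in
// this translation unit. Placeholder logic: every reduction tensor must
// reduce the same axes as the first one.
bool checkReductionPattern(const std::vector<TensorView*>& reduction_tvs) {
  for (std::size_t i = 1; i < reduction_tvs.size(); ++i) {
    assert(reduction_tvs[0] != nullptr && reduction_tvs[i] != nullptr);
    if (reduction_tvs[i]->is_reduction_axis !=
        reduction_tvs[0]->is_reduction_axis) {
      return false;
    }
  }
  return true;
}

} // namespace

class InnerPersistentKernelScheduler {
 public:
  // The class keeps only its entry points and calls the helper unqualified.
  // The parameter list here is simplified for the sketch.
  static bool canScheduleCompileTime(
      const std::vector<TensorView*>& reduction_tvs) {
    return checkReductionPattern(reduction_tvs);
  }
};
```

Because the helper now has internal linkage, it needs no declaration in the scheduler's header, and changing its signature does not touch the class interface.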

bool checkReductionPattern(
    Fusion* fusion,
    const std::vector<TensorView*>& reduction_tvs) {
  // Use root domain map to check the reduction ops have the same axes
Collaborator Author

(1) Transformed from a static function in the InnerPersistentKernelScheduler class into a utility function within an anonymous namespace.
(2) Removed checks that were unrelated to inner persistence.

    Fusion* fusion,
    SchedulerRuntimeInfo& runtime_info,
    HeuristicSummary* data_cache,
    const std::vector<TensorView*>& reduction_tvs) {
Collaborator Author

(1) Transformed from a static function in the InnerPersistentKernelScheduler class into a utility function within an anonymous namespace.
(2) Removed checks that were unrelated to inner persistence.


- if (inner_reduction_tvs.empty() || !outer_reduction_tvs.empty()) {
+ if (reduction_type != reduction_scheduler_utils::ReductionType::Inner) {
    scheduler_debug_utils::canScheduleRejectReason(
Collaborator Author

Simplified the reduction type check by using a utility function.
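The two conditions in the hunk above are meant to reject the same cases. Here is a minimal sketch of why, assuming a classification helper along the lines of the one in `reduction_scheduler_utils`. The `TensorView` struct, its field, the body of `getReductionType`, and the `rejectsOld`/`rejectsNew` wrappers are illustrative assumptions, not the nvFuser implementation.

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-in for nvFuser's TensorView (assumption for this sketch).
struct TensorView {
  bool is_fastest_dim_reduction;  // true: inner reduction, false: outer
};

enum class ReductionType { Inner, Outer, InnerOuter, None };

// Assumed behavior: classify a set of reduction tensors by which kinds of
// reduction are present.
ReductionType getReductionType(const std::vector<TensorView*>& reduction_tvs) {
  bool has_inner = false;
  bool has_outer = false;
  for (const TensorView* tv : reduction_tvs) {
    assert(tv != nullptr);
    if (tv->is_fastest_dim_reduction) {
      has_inner = true;
    } else {
      has_outer = true;
    }
  }
  if (has_inner && has_outer) {
    return ReductionType::InnerOuter;
  }
  if (has_inner) {
    return ReductionType::Inner;
  }
  if (has_outer) {
    return ReductionType::Outer;
  }
  return ReductionType::None;
}

// Old form of the check: reject unless there are inner reductions and no
// outer reductions.
bool rejectsOld(
    const std::vector<TensorView*>& inner_reduction_tvs,
    const std::vector<TensorView*>& outer_reduction_tvs) {
  return inner_reduction_tvs.empty() || !outer_reduction_tvs.empty();
}

// New form of the check: reject anything that is not classified as Inner.
bool rejectsNew(const std::vector<TensorView*>& reduction_tvs) {
  return getReductionType(reduction_tvs) != ReductionType::Inner;
}
```

Under these assumptions the type is `Inner` exactly when at least one inner reduction exists and no outer reduction does, so both forms accept and reject the same inputs.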

  // the iter domain of the persistent reduction.
  if (!properties.fastest_dim_reduction &&
      !(norm_per_sm >= warp_size / 2 ||
        max_multi_reduction_factor >= warp_size)) {
Collaborator Author

Removed checks not used by the inner persistent scheduler.

  // If the persistence requires over half the device don't do grid
  // persistence as we can't overlap the grid comms.
  if (required_sm_per_norm >
      scheduler_utils::safeDiv(device_multiprocessor_count, 3)) {
Collaborator Author

Changed the divisor from 3 to 2 to match the code comment ("over half the device").
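To make the effect of the divisor concrete, here is a small sketch. It assumes `safeDiv` is integer division clamped to a minimum of 1. `exceedsGridPersistenceLimit` is a hypothetical wrapper around the condition in the hunk above, not a function from the codebase.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Assumed behavior of scheduler_utils::safeDiv: integer division whose
// result is never smaller than 1.
int64_t safeDiv(int64_t x, int64_t y) {
  assert(y != 0);
  return std::max<int64_t>(x / y, 1);
}

// Hypothetical wrapper around the condition in the hunk: true when one
// normalization needs more SMs than the allowed fraction of the device.
bool exceedsGridPersistenceLimit(
    int64_t required_sm_per_norm,
    int64_t device_multiprocessor_count,
    int64_t divisor) {
  return required_sm_per_norm >
      safeDiv(device_multiprocessor_count, divisor);
}
```

On a device with 108 SMs, taken here only as an example value, a divisor of 3 rejects grid persistence once a normalization needs more than 36 SMs, while a divisor of 2 raises that limit to 54, which is what "over half the device" describes.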

}

bool InnerPersistentKernelScheduler::canScheduleRunTimeOuter(
    Fusion* fusion,
Collaborator Author

This function was deleted.

  }

  return rparams;
}
Collaborator Author

The heuristics for outer and innerOuter were removed.

@liqiangxl
Collaborator Author

!build

@liqiangxl liqiangxl marked this pull request as ready for review September 21, 2023 21:38
@liqiangxl liqiangxl requested a review from naoyam September 21, 2023 21:39
@liqiangxl
Collaborator Author

!build

@liqiangxl
Collaborator Author

!build

@liqiangxl liqiangxl marked this pull request as draft September 22, 2023 14:59
@liqiangxl
Collaborator Author

!build

@liqiangxl
Collaborator Author

!build

@liqiangxl liqiangxl marked this pull request as ready for review September 22, 2023 19:43
@liqiangxl
Collaborator Author

!build

Collaborator

@naoyam naoyam left a comment

LGTM. Just in case, please do the diff check. It's currently disabled in CI, so it needs to be done manually.

@liqiangxl liqiangxl merged commit b4335f0 into main Sep 23, 2023
@liqiangxl liqiangxl deleted the llu/clean_inner_norm branch September 23, 2023 18:32
@liqiangxl
Collaborator Author

Previous HEAD position was 8facc54 Define TensorMap for CUDA 11 (#932)
Switched to branch 'llu/clean_inner_norm'
Your branch is up to date with 'origin/llu/clean_inner_norm'.

DIFF RESULT:

No difference found
Already on 'llu/clean_inner_norm'
Your branch is up to date with 'origin/llu/clean_inner_norm'.

@naoyam
Collaborator

naoyam commented Sep 23, 2023

> Previous HEAD position was 8facc54 Define TensorMap for CUDA 11 (#932)
> Switched to branch 'llu/clean_inner_norm'
> Your branch is up to date with 'origin/llu/clean_inner_norm'.
>
> DIFF RESULT:
>
> No difference found
> Already on 'llu/clean_inner_norm'
> Your branch is up to date with 'origin/llu/clean_inner_norm'.

Thanks. Just please make sure to run the benchmarks as I don't think the script runs them by default.

@liqiangxl
Collaborator Author

> > Previous HEAD position was 8facc54 Define TensorMap for CUDA 11 (#932)
> > Switched to branch 'llu/clean_inner_norm'
> > Your branch is up to date with 'origin/llu/clean_inner_norm'.
> >
> > DIFF RESULT:
> >
> > No difference found
> > Already on 'llu/clean_inner_norm'
> > Your branch is up to date with 'origin/llu/clean_inner_norm'.
>
> Thanks. Just please make sure to run the benchmarks as I don't think the script runs them by default.

The CI runs the benchmarks, e.g. https://gitlab-master.nvidia.com/dl/pytorch/fuser-gh-mirror/-/jobs/69890090

@naoyam
Collaborator

naoyam commented Sep 23, 2023

I think the diff check is not enabled at this moment in the CI.

@liqiangxl
Collaborator Author

> I think the diff check is not enabled at this moment in the CI.

You are right. The diff check is not enabled in CI. But CI runs the benchmark job jit_nvfuser_bench_jitfuture_TNVF, and I ran tools/compare_codegen.sh on my local node.

@naoyam
Collaborator

naoyam commented Sep 23, 2023

> > I think the diff check is not enabled at this moment in the CI.
>
> You are right. The diff check is not enabled in CI. But CI runs the benchmark Job jit_nvfuser_bench_jitfuture_TNVF and I ran the tools/compare_codegen.sh on my local node.

What I wanted to make sure of was that the benchmarks were run with the diff script, as it doesn't run them by default.

@liqiangxl
Collaborator Author

> > > I think the diff check is not enabled at this moment in the CI.
> >
> > You are right. The diff check is not enabled in CI. But CI runs the benchmark Job jit_nvfuser_bench_jitfuture_TNVF and I ran the tools/compare_codegen.sh on my local node.
>
> What I wanted to make sure was to run the benchmarks with the diff script as it doesn't by default.

Got you! I did this check in #928 using:

tools/compare_codegen.sh -- build/nvfuser_bench --benchmark_filter=NvFuserScheduler --benchmark_repetitions=1 --benchmark_min_time=0

No diff found.

liqiangxl added a commit that referenced this pull request Sep 26, 2023
similar to #923
(1) Transformed private static functions in the OuterPersistentKernelScheduler class into utility functions within an anonymous namespace.
(2) Removed checks/calculations that are unrelated to outer persistent scheduler.
liqiangxl added a commit that referenced this pull request Sep 27, 2023
similar to #923
(1) Transformed private static functions in the InnerOuterPersistentKernelScheduler class into utility functions within an anonymous namespace.
(2) Removed checks/calculations that are unrelated to inner_outer persistent scheduler.