-
Notifications
You must be signed in to change notification settings - Fork 79
Expand file tree
/
Copy pathexecutor.cpp
More file actions
1725 lines (1532 loc) · 59.7 KB
/
executor.cpp
File metadata and controls
1725 lines (1532 loc) · 59.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
// clang-format off
/*
* SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
* All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*/
// clang-format on
#include "runtime/executor.h"
#include <cmath>
#include <cstring>
#include <unordered_set>
#include <ATen/core/LegacyTypeDispatch.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/llvm_jit_strings.h>
#include <ATen/native/cuda/jit_utils.h>
#include <c10/core/DeviceGuard.h>
#include <c10/cuda/CUDAFunctions.h>
#include <c10/cuda/CUDAStream.h>
#include "base.h"
#include "codegen.h"
#include "debug.h"
#include "device_lower/analysis/bank_conflict.h"
#include "device_lower/lower2device.h"
#include "device_lower/utils.h"
#include "driver_api.h"
#include "fusion_profiler.h"
#include "global_allocator.h"
#include "host_ir/container.h"
#include "host_ir/lower_to_communication.h"
#include "instrumentation.h"
#include "ir/all_nodes.h"
#include "ir/graphviz.h"
#include "ir/utils.h"
#include "iter_visitor.h"
#include "kernel_ir.h"
#include "multidevice/execution_utils.h"
#include "multidevice/utils.h"
#include "options.h"
#include "polymorphic_value.h"
#include "runtime/allocations.h"
#include "runtime/executor_kernel_arg.h"
#include "runtime/executor_utils.h"
#include "serde/utils.h"
#include "tensor_metadata.h"
namespace nvfuser {
// Lazily constructs (on first use) and returns the cached
// PrecomputedValues for the compiled kernel. The returned reference
// aliases the member cache, so callers share the same instance.
std::unique_ptr<PrecomputedValues>& KernelExecutor::
    evaluatorPrecomputedValues() {
  if (!evaluator_precomputed_values_) {
    evaluator_precomputed_values_ =
        std::make_unique<PrecomputedValues>(compiledKernel()->kernel());
  }
  return evaluator_precomputed_values_;
}
// Returns true iff every fusion output is produced via expression
// evaluation (AllocationType::Evaluate), i.e., this executor can
// handle the fusion without generating a CUDA kernel.
bool ExprEvalExecutor::supported(Fusion* fusion) {
  FUSER_PERF_SCOPE("ExprEvalExecutor::supported");
  const auto& outs = fusion->outputs();
  auto is_evaluated = [fusion](Val* out) {
    return fusion->getOutputAlias(out).type == AllocationType::Evaluate;
  };
  return std::all_of(outs.begin(), outs.end(), is_evaluated);
}
// "Compilation" for this executor is just cloning the fusion so run()
// can evaluate it later. Requires that all outputs are marked
// Evaluate (see supported()). Compile time is recorded in the fusion
// profiler when enabled.
void ExprEvalExecutor::compile(Fusion* fusion) {
  FUSER_PERF_SCOPE("ExprEvalExecutor::compile");
  if (isProfilerEnabled()) {
    FusionProfiler::segment(group_id_).startCompile();
  }
  NVF_ERROR(
      supported(fusion),
      "ExprEvalExecutor does not support the Fusion provided.");
  // Keep a private copy so later mutations of the caller's fusion do
  // not affect this executor.
  fusion_ = std::make_unique<Fusion>(*fusion);
  if (isProfilerEnabled()) {
    FusionProfiler::segment(group_id_).stopCompile();
  }
}
// Compiled state is simply "the fusion clone exists" (see compile()).
bool ExprEvalExecutor::isCompiled() const {
  return static_cast<bool>(fusion_);
}
// Evaluates every fusion output with the expression evaluator and
// returns them in a KernelArgumentHolder.
//
// `args` holds the fusion inputs; `outputs` must be empty because this
// executor always materializes the outputs itself. Profiler bytes/
// timings are recorded when the profiler is enabled.
//
// Fix: the original guarded the output loop with `if (outputs.empty())`
// immediately after NVF_ERROR(outputs.empty(), ...) asserted exactly
// that — the guard (and its extra scope block) was dead code and has
// been removed.
KernelArgumentHolder ExprEvalExecutor::run(
    const KernelArgumentHolder& args,
    KernelArgumentHolder outputs) {
  FUSER_PERF_SCOPE("ExprEvalExecutor::run");
  if (isProfilerEnabled()) {
    NVF_CHECK(
        group_id_ >= 0,
        "An invalid segment id is passed to FusionProfiler!:",
        group_id_);
    SegmentProfiler& sprof = FusionProfiler::segment(group_id_);
    sprof.inputBytesAccessed(computeBytes(args));
    sprof.scheduler(toString(SchedulerType::ExprEval));
    sprof.startKernel();
  }
  NVF_ERROR(fusion_, "Need to compile before you can run.");
  // Bind fusion inputs
  auto expr_eval = executor_utils::bindInputs(args, fusion_.get());
  NVF_ERROR(
      outputs.empty(),
      "Fusion executor is using expression evaluator,",
      " and expects that the outputs are not populated, which they were.");
  for (const auto& out_val : fusion_->outputs()) {
    // Evaluate the output and bind it back so later outputs that
    // depend on this value reuse the computed tensor.
    auto out_tensor = expr_eval.evaluate(out_val).as<at::Tensor>();
    expr_eval.bind(out_val, out_tensor);
    outputs.push(out_tensor);
  }
  if (isProfilerEnabled()) {
    FusionProfiler::segment(group_id_).stopKernel();
    FusionProfiler::segment(group_id_).setDevice(args.getDeviceIndex());
  }
  return outputs;
}
namespace {
// Returns true if any fusion output is a CPU scalar tensor, i.e., is
// produced by an expr whose tensor inputs are all CPU scalars.
//
// Fix: the original collected results in an unordered_map<TV*, bool>
// and queried it with operator[] inside the any_of predicate, which
// default-inserts entries (mutating the map during the query). An
// unordered_set lookup avoids that.
bool hasCpuScalarOutputs(Fusion* _fusion) {
  if (_fusion->exprs().empty()) {
    return false;
  }
  // TensorViews known to be CPU scalars (outputs of CPU-only exprs).
  std::unordered_set<TensorView*> cpu_tvs;
  for (Expr* expr : StmtSort::getExprs(_fusion)) {
    bool has_cpu_scalar_input = false;
    bool has_cuda_input = false;
    for (Val* inp : expr->inputs()) {
      if (auto* inp_tv = dynamic_cast<TensorView*>(inp)) {
        if (inp_tv->isCpuScalar()) {
          has_cpu_scalar_input = true;
        } else {
          has_cuda_input = true;
          // Return early -- found at least one CUDA input
          break;
        }
      }
    }
    if (!has_cuda_input && has_cpu_scalar_input) {
      // Expr consumes only CPU scalars, so its tensor outputs are CPU
      // scalars as well.
      for (Val* out : expr->outputs()) {
        if (auto* out_tv = dynamic_cast<TensorView*>(out)) {
          cpu_tvs.insert(out_tv);
        }
      }
    }
  }
  return std::any_of(
      _fusion->outputs().begin(),
      _fusion->outputs().end(),
      [&cpu_tvs](Val* out) {
        return out->isA<TensorView>() &&
            cpu_tvs.count(out->as<TensorView>()) > 0;
      });
}
} // namespace
// A fusion can be compiled into a CUDA kernel only when it produces no
// CPU scalar tensor outputs.
bool KernelExecutor::supported(Fusion* fusion) {
  FUSER_PERF_SCOPE("KernelExecutor::supported");
  const bool has_cpu_scalar_outputs = hasCpuScalarOutputs(fusion);
  return !has_cpu_scalar_outputs;
}
// Lowers and compiles `fusion` into a CUDA kernel for the device given
// by `args`. When `args` is non-empty, launch parameters are computed
// up front so the compiled kernel can validate register sharing and
// static CTA shapes. Also records kernel traits (TMA use, RNG use,
// dynamic aliases) that are consulted at run time.
//
// Fixes: removed a duplicated `compile_params.index_type =
// arg_index_type;` statement, and added an early `break` once
// has_dynamic_alias_ is set (remaining inputs cannot change it).
void KernelExecutor::compile(
    Fusion* fusion,
    const KernelArgumentHolder& args,
    const LaunchParams& launch_constraints,
    CompileParams compile_params,
    SchedulerType scheduler_type) {
  FUSER_PERF_SCOPE("KernelExecutor::compile");
  NVF_ERROR(
      supported(fusion),
      "KernelExecutor does not support the Fusion provided.");
  NVF_ERROR(
      !fusion->outputs().empty(), "No output found for this kernel, aborting.");
  c10::Device device(c10::DeviceType::CUDA, args.getDeviceIndex());
  if (isProfilerEnabled()) {
    NVF_CHECK(
        group_id_ >= 0,
        "An invalid segment id is passed to FusionProfiler!:",
        group_id_);
    FusionProfiler::segment(group_id_).setDevice(device.index());
    FusionProfiler::segment(group_id_).startCompile();
  }
  //! Force index_type to int and disable magic zero if we detect that the
  //! kernel contains any TMA memory operations.
  std::vector<Expr*> exprs = fusion->exprs();
  bool has_cp_async_bulk = std::any_of(exprs.begin(), exprs.end(), [](Expr* e) {
    return ir_utils::isCpAsyncBulk(e);
  });
  // Disable magic zero if there are any TMA operations in Fusion
  if (has_cp_async_bulk) {
    compile_params.enable_magic_zero = false;
  }
  // Set the index type of compile params if not already set. If set,
  // make sure the compile param type is valid with the given kernel
  // arguments.
  auto arg_index_type = args.getSmallestIndexTypeOfArguments();
  if (compile_params.index_type.has_value()) {
    // If the int32 compilation is requested, but the arguments demand
    // int64, that's an error
    NVF_ERROR(
        !(compile_params.index_type.value() == PrimDataType::Int32 &&
          arg_index_type == PrimDataType::Int),
        "Compilation with int32 is requested but int64 is required for the "
        "arguments");
  } else {
    // If the given compile option doesn't specify the index type, and
    // the arguments require 64-bit indexing, we need to use 64-bit
    // indexing. Note that if the arg type is 32-bit, it doesn't mean
    // it's safe to use 32-bit for the whole kernel, so unless it's
    // specified through CompileParams, we do not use 32-bit indexing.
    compile_params.index_type = arg_index_type;
  }
  c10::DeviceGuard dg(device);
  NVF_ERROR(device.is_cuda(), "Provided device to CUDA fuser is the CPU.");
  auto properties = at::cuda::getDeviceProperties(device.index());
  // TODO: These properties should be set as part of the constructor so that it
  // can be const
  device_smem_limit_ = static_cast<int64_t>(properties->sharedMemPerBlockOptin);
  warp_size_ = properties->warpSize;
  // Lowered is needed to compute launch parameters as it uses the CA map. We
  // could modify that, but simply generating that part first.
  compiled_kernel_ = std::make_unique<CompiledKernel>(
      fusion,
      compile_params,
      device,
      scheduler_type,
      fusion_id_,
      concrete_id_,
      runtime_id_,
      group_id_,
      lowering_hooks_,
      post_lowering_hooks_);
  // TODO: pass block_size here;
  std::optional<int64_t> dynamic_smem = std::nullopt;
  std::optional<int64_t> block_size = std::nullopt;
  LaunchParams launch_params = launch_constraints;
  if (!args.empty()) {
    auto expr_eval =
        executor_utils::bindInputs(args, compiled_kernel_->lowered()->kernel());
    NVF_ERROR(compile_params.index_type.has_value());
    launch_params = computeLaunchParams(
        launch_constraints,
        expr_eval,
        warp_size_,
        compile_params.index_type.value());
    block_size = launch_params.nThreads();
    dynamic_smem = launch_params.smem();
    NVF_ERROR_GT(*block_size, 0);
  }
  // Launch parameters are required to compile the kernel to:
  // (1) validate register sharing
  // (2) runtime function may use static CTA shape, e.g.
  //     iterGroupedStaticWarpAllReduce
  compiled_kernel_->compile(launch_params);
  // These should be nullopt at this point, but reset just in case
  resetCompiledKernelProperties();
  // If the dynamic shmem size is known, make sure the compiled kernel
  // has at least that size of dynamic shmem
  if (dynamic_smem.has_value()) {
    ensureAvailableDynamicSmemSize(dynamic_smem.value());
  }
  if (isProfilerEnabled()) {
    FusionProfiler::segment(group_id_).stopCompile();
  }
  // Record run-time-relevant kernel traits.
  for (Expr* expr : exprs) {
    if (ir_utils::isCpAsyncBulk(expr)) {
      has_tma_ = true;
    }
    if (expr->isA<RNGOp>()) {
      has_rng_ = true;
    }
  }
  // If an output has an alias to an input and is marked Evaluate, then
  // expression evaluator evaluate is called on that output to produce the meta
  // data manipulation it requires. If that manipulation is something like a
  // slice, and that slice has a symbolic integer it depends on, then this
  // function returns true.
  //
  // This could happen for other examples and has_dynamic_alias_ will be true if
  // to evaluate the output that has an alias, other values besides the aliased
  // input need to be bound to the expression evaluator to evaluate the output.
  for (TensorView* out_tv :
       ir_utils::filterByType<TensorView>(fusion->outputs())) {
    auto alias_info = fusion->getOutputAlias(out_tv);
    if (alias_info.type != AllocationType::Evaluate) {
      continue;
    }
    auto aliased_to = alias_info.aliased_io->as<TensorView>();
    auto inputs = InputsOf::output(out_tv);
    for (auto input : inputs) {
      if (input->isA<TensorView>() && input->sameAs(aliased_to)) {
        continue;
      }
      if (input->isConst()) {
        continue;
      }
      has_dynamic_alias_ = true;
      // Once flagged there is no need to examine the remaining inputs.
      break;
    }
  }
}
// Computes the kernel launch parameters (grid/block dimensions and
// dynamic shared memory size) by binding `launch_constraints` and
// inferring the extents of all parallelized IterDomains with
// `expr_eval`. `warp_size` and `index_type` feed into the shared
// memory workspace computation for reductions/broadcasts.
//
// Fix: the scrape-mangled lambda capture "[¶llel_binding_ids]" is
// restored to [&parallel_binding_ids] (the "&para" prefix had been
// mis-encoded as the pilcrow character).
LaunchParams KernelExecutor::computeLaunchParams(
    const LaunchParams& launch_constraints,
    ExpressionEvaluator& expr_eval,
    const int64_t warp_size,
    DataType index_type) {
  FUSER_PERF_SCOPE("KernelExecutor::computeLaunchParams");
  NVF_ERROR(warp_size > 0, "WARP_SIZE should be larger than 0");
  LaunchParams launch_params;
  auto data_cache = compileTimeDataCache();
  auto lower = compiled_kernel_->lowered().get();
  if (compiled_kernel_->getUsedTVs().empty()) {
    compiled_kernel_->setUsedTVs();
  }
  auto& used_tvs = compiled_kernel_->getUsedTVs();
  // Cached (per-compilation) list of parallel-bound IterDomains.
  auto parallel_binding_ids_entry =
      executor_utils::caching::ExecutorCompileTimeEntry<
          executor_utils::caching::ParallelBindingIterDomains>(
          data_cache, [&used_tvs, &lower]() {
            return std::make_unique<std::vector<IterDomain*>>(
                executor_utils::getParallelBindingsIterDomains(
                    lower, used_tvs));
          });
  auto& parallel_binding_ids = parallel_binding_ids_entry.get();
  // Cached map from parallel type to the extents bound to it.
  auto parallel_iter_extent_entry =
      executor_utils::caching::ExecutorCompileTimeEntry<
          executor_utils::caching::ParallelIterExtentMap>(
          data_cache, [&parallel_binding_ids]() {
            return executor_utils::getParallelIterExtents(parallel_binding_ids);
          });
  auto& parallel_iter_extents = parallel_iter_extent_entry.get();
  const auto& simplified_parallel_iter_extents =
      lower->info().parallelDimensionMap().getMap();
  // TODO: Need to redesign this part a bit to
  //   find the right place to trigger evaluate
  if (expr_eval.precomputedValues()) {
    expr_eval.precomputedValues()->bindParallelExtents(
        parallel_iter_extents, launch_constraints);
    expr_eval.precomputedValues()->evaluate();
  }
  // If any dimension was set in launch constraints we need to run through
  // IterDomains that have been parallelized, and bind those values. Or make
  // sure if they could be inferred the inference matches what was set.
  for (auto& entry : parallel_iter_extents) {
    auto p_type = entry.first;
    if (launch_constraints.hasDim(p_type)) {
      auto parallel_extents = entry.second;
      for (auto extent : parallel_extents) {
        auto inferred_val = expr_eval.evaluate(extent);
        if (inferred_val.hasValue()) {
          // This value could have been inferred, make sure it was set right.
          bool valid =
              inferred_val.as<int64_t>() == launch_constraints.getDim(p_type) ||
              launch_constraints.getRawVal(p_type) == -1;
          if (!useFallback() && !valid) {
            TORCH_WARN_ONCE(
                "Cannot validate parallelization scheme, "
                "this may be due to mixed broadcast axes that are "
                "parallelized.");
          }
        } else if (!expr_eval.precomputedValues()) {
          expr_eval.bind(extent, launch_constraints.getDim(p_type));
        }
        if (!launch_params.hasDim(p_type)) {
          // Bind the launch constraint into our evaluation context
          launch_params.bind(launch_constraints.getDim(p_type), p_type);
          // Makes sure the p-types bound to evaluators are the
          // final values that will become the actual launch
          // param size to ensure accurate smem buffer size
          // computation.
          expr_eval.bind(p_type, launch_constraints.getDim(p_type));
        }
      }
    }
  }
  // Run through the rest of the parallel IterDomains and infer their size
  for (auto [p_type, extent] : simplified_parallel_iter_extents) {
    FUSER_PERF_SCOPE("KernelExecutor::ParallelBindingResolution");
    auto val = expr_eval.evaluate(extent);
    NVF_ERROR(
        val.hasValue(),
        "Tried to evaluate the extent, ",
        extent->toInlineString(),
        " for the ptype: ",
        p_type,
        " to set launch bounds but could not.");
    if (val > 0) {
      expr_eval.bind(p_type, val);
      launch_params.bind(val.as<int64_t>(), p_type);
    }
  }
  // Re-run the integer machine with all
  // the thread sizes now determined.
  if (expr_eval.precomputedValues()) {
    expr_eval.precomputedValues()->evaluate();
  }
  const auto kernel = compiled_kernel_->lowered()->kernel();
  const auto& kernel_summary = kernel->summary();
  // Calculate Dynamic Shared Memory Size
  // Add workspace for reduction and broadcast
  int64_t reduction_broadcast_workspace = 0;
  const bool has_workspace = kernel_summary.has_block_reductions ||
      kernel_summary.has_grid_reductions ||
      kernel_summary.has_block_broadcasts || kernel_summary.has_grid_broadcasts;
  if (has_workspace &&
      kernel_summary.largest_smem_data_type != DataType::Null) {
    // Not using nThreads here since it does not handle uninitialized value
    // TODO: here is an optimization opportunity since welford uses int64_t for
    // N while the data type is not necessarily double. But it may need more
    // work on the alignment
    const int welford_factor =
        kernel_summary.has_block_welford || kernel_summary.has_grid_welford ? 3
                                                                            : 1;
    // in outer reduction, may group iteration domain, e.g. when vectorized.
    const int64_t grouped_iter_factor = kernel_summary.num_grouped_iterations;
    NVF_CHECK(
        !(kernel_summary.has_iter_grouped_reductions && welford_factor == 3),
        "can't have welford and iter grouped reductions at the same time! "
        "Should be handled by grouped welford!");
    // For block reduction, each thread has a smem slot per reduction
    // When warp specialization is used, remove padded threads
    // For warp reduction, each warp has a smem slot per reduction
    int64_t n_compute_threads_or_warps = launch_params.nThreads();
    if (kernel_summary.circular_buffer_info.hasWarpSpecialized()) {
      n_compute_threads_or_warps -= kWarpSpecializationPaddedThreads;
    }
    if (kernel_summary.all_block_reductions_are_warp_reduction) {
      n_compute_threads_or_warps /= 32;
    }
    reduction_broadcast_workspace =
        dataTypeSizeByte(kernel_summary.largest_smem_data_type, index_type) *
        grouped_iter_factor * welford_factor * n_compute_threads_or_warps;
    if (kernel_summary.has_outer_grouped_grid_welford) {
      reduction_broadcast_workspace = std::max(
          reduction_broadcast_workspace,
          (int64_t)kernel_summary.outer_grouped_grid_welford_largest_smem_size);
    }
    reduction_broadcast_workspace =
        alignSharedMemoryBytes(reduction_broadcast_workspace);
    if (isDebugDumpEnabled(DebugDumpOption::DynamicSharedMemory)) {
      debug() << "reduction_broadcast_workspace shared memory bytes: "
              << reduction_broadcast_workspace << std::endl;
    }
  }
  const auto dynamic_smem_size = computeSharedMemory(
      expr_eval,
      kernel_summary.dynamic_smem_allocations,
      index_type,
      reduction_broadcast_workspace);
  // Check that requested smem size can be dynamically allocated.
  // This check is only done once a kernel has been compiled, since
  // maybe_available_dynamic_smem_ needs to be evaluated on
  // a compiled kernel.
  if (compiled_kernel_->isCompiled()) {
    validateDynamicSmemSize(dynamic_smem_size);
  }
  launch_params.setSmem(dynamic_smem_size);
  return launch_params;
}
// Collects allocation info for the kernel's global intermediate
// buffers (grid-reduction workspaces, etc.). Fusion outputs and
// aliased allocations are skipped: outputs are allocated elsewhere,
// and aliases reuse another buffer inside the kernel.
std::vector<GlobalBufferInfo> KernelExecutor::getIntermediateBufferInfo(
    ExpressionEvaluator& expr_eval,
    DataType index_type) {
  FUSER_PERF_SCOPE("KernelExecutor::getIntermediateBufferInfo");
  std::vector<GlobalBufferInfo> global_buffers;
  const auto kernel = compiled_kernel_->lowered()->kernel();
  const auto& kernel_summary = kernel->summary();
  for (auto alloc : kernel_summary.global_allocations) {
    NVF_ERROR(
        alloc->buffer()->isA<TensorView>(),
        "Cannot allocate global buffers that are not tensors.");
    auto tv = alloc->buffer()->as<TensorView>();
    if (tv->isFusionOutput()) {
      // Outputs are allocated by the caller, not as intermediates.
      continue;
    }
    if (alloc->alias() != nullptr) {
      // When aliased, no tensor argument is passed to the
      // kernel. Inside the kernel, the aliasing tensor is defined as
      // an alias of the aliasee, e.g., "auto& T2 = T4". The validity
      // of the aliasing should be confirmed at the time of lowering.
      continue;
    }
    GlobalBufferInfo info;
    info.tv = tv;
    info.zero_init = alloc->zeroInit();
    info.resets_to_zero = alloc->resetsToZero();
    // TODO: Allocation size needs to consider both expanded domains
    // as well as halo. Currently, halo support has been removed so we only
    // need to worry about the expand case which is handled in
    // inferShapeAndContiguousStrides. There used to also be a
    // inferShapeOfIntermediateAndContiguousStride function before this commit,
    // but that was safely removed with halo support. This will need to be
    // revisited when halo support is added again.
    auto [sizes, strides] = inferShapeAndContiguousStrides(tv, expr_eval);
    info.shape_info.logical_sizes = sizes;
    info.shape_info.logical_strides = strides;
    // DataType::Index is resolved to the kernel's concrete index type.
    auto dtype = tv->dtype() == DataType::Index ? index_type : tv->dtype();
    info.type = data_type_to_aten(dtype);
    // Remember the tensor buffer used for storing kernel profile
    if (isOptionEnabled(EnableOption::KernelProfile) &&
        tv == kernel->profile().getBuffer()) {
      info.is_profile_buffer = true;
    }
    global_buffers.emplace_back(info);
  }
  return global_buffers;
}
namespace {
// Make sure the index type of Kernel is valid: when compilation
// requested a specific index type, the lowered kernel must use it.
void validateIndexType(
    kir::Kernel* kernel,
    const CompileParams& compile_params) {
  if (!compile_params.index_type.has_value()) {
    // No requested type — nothing to check.
    return;
  }
  NVF_ERROR(
      kernel->indexType() == compile_params.index_type.value(),
      "Kernel index type and compilation index type don't match. Kernel type: ",
      kernel->indexType(),
      ". Compilation index type: ",
      compile_params.index_type.value());
}
// Validates that a cooperative kernel launch fits on the device: every
// block of the grid must be able to be resident on the GPU at once
// (occupancy per SM times SM count must cover the grid size).
//
// Fix: the device properties were queried twice — once with a
// c10::DeviceIndex cast and once passing the raw int64_t (implicit
// narrowing). The query is now done once with an explicit cast.
void validateCooperativeLaunch(
    CUfunction kernel,
    const LaunchParams& launch_params,
    int64_t device_index) {
  int num_blocks_per_SM = -1;
  auto block_size =
      launch_params.bdimx() * launch_params.bdimy() * launch_params.bdimz();
  NVFUSER_CUDA_SAFE_CALL(cuOccupancyMaxActiveBlocksPerMultiprocessor(
      &num_blocks_per_SM,
      kernel,
      (int)block_size,
      (size_t)launch_params.smem()));
  auto grid_size =
      launch_params.gdimx() * launch_params.gdimy() * launch_params.gdimz();
  // Query the device once; the SM count is used both for the limit and
  // for the error message below.
  const auto sm_count =
      at::cuda::getDeviceProperties((c10::DeviceIndex)device_index)
          ->multiProcessorCount;
  auto max_active_blocks = num_blocks_per_SM * sm_count;
  NVF_ERROR(
      (int64_t)(max_active_blocks) >= grid_size,
      "Wanted to launch a cooperative kernel, however the number of blocks is "
      "greater than ",
      "what can be resident on the GPU at once. Need: ",
      grid_size,
      " (",
      launch_params.gdimx(),
      " * ",
      launch_params.gdimy(),
      " * ",
      launch_params.gdimz(),
      ") but limited to ",
      num_blocks_per_SM,
      " * ",
      sm_count);
}
// Dump fusion inputs and outputs as well as some useful fusion
// information. Note that inputs and outputs are those that are passed
// to KernelExecutor::runFusion, so outputs may not be given.
//
// Fix: the header string was missing the space before the fusion id
// ("Arguments for fusionN"); now matches dumpKernelArgs' format.
void dumpFusionArgs(
    int64_t fusion_id,
    const KernelArgumentHolder& args,
    const LaunchParams& launch_constraints,
    const CompileParams& compile_params,
    const KernelArgumentHolder& outputs) {
  debug() << "Arguments for fusion " << fusion_id << ":" << std::endl
          << "Inputs:" << std::endl;
  for (auto i : arange(args.size())) {
    debug() << " " << args[i] << std::endl;
  }
  debug() << "Outputs:" << std::endl;
  for (const auto& output : outputs) {
    debug() << PolymorphicValue_functions::toString(output) << std::endl;
  }
  debug() << launch_constraints.toString();
  debug() << "maxrregcount= " << compile_params.maxrregcount << std::endl;
}
// Dump arguments that are passed to a CUDA kernel call, which include
// the inputs and outputs of the fusion as well as temporary
// global-memory buffers. Unlike dumpFusionArgs, which dumps inputs
// and outputs passed to KernelExecutor::runFusion, this function
// dumps those that are passed to a CUDA kernel.
void dumpKernelArgs(
    const int64_t fusion_id,
    const int64_t group_id,
    const KernelArgumentHolder& args,
    size_t num_inputs,
    const KernelArgumentHolder& allocated_outputs,
    const KernelArgumentHolder& intermediates,
    const std::vector<GlobalBufferInfo>& intermediates_info) {
  using namespace PolymorphicValue_functions;
  debug() << "Arguments for fusion " << fusion_id << " group " << group_id
          << ":" << std::endl
          << "Inputs:" << std::endl;
  for (auto input_idx : arange(num_inputs)) {
    debug() << " " << toString(args[input_idx]) << std::endl;
  }
  debug() << "Outputs:" << std::endl;
  // note: add aliased outputs here.
  for (const auto& output : allocated_outputs) {
    debug() << " " << toString(output) << std::endl;
  }
  debug() << "Intermediate global buffers:" << std::endl;
  for (const auto buf_idx : arange(intermediates.size())) {
    const auto& buf_info = intermediates_info.at(buf_idx);
    debug() << " " << toString(intermediates[buf_idx])
            << " is_zero_initialized: " << buf_info.zero_init
            << " resets_to_zero: " << buf_info.resets_to_zero << std::endl;
  }
}
} // namespace
// Populates `executor_entry` with everything needed to launch the
// compiled kernel for this concrete set of argument shapes: launch
// parameters, per-tensor size/stride info for inputs, outputs, and
// intermediates, and the output->input alias map. The arguments are
// validated (summary validations, vectorization, circular buffering,
// index casts, ElectSync warp requirement) before anything is saved.
void KernelExecutor::initializeExecutorEntry(
    KernelExecutorEntry& executor_entry,
    const KernelArgumentHolder& args,
    const LaunchParams& launch_constraints,
    const CompileParams& compile_params,
    const KernelArgumentHolder& output_args,
    DataType index_type) {
  FUSER_PERF_SCOPE("KernelExecutor::initializeExecutorEntry");
  ExpressionEvaluator expr_eval =
      executor_utils::bindInputs(args, compiled_kernel_->kernel());
  auto launch_params = computeLaunchParams(
      launch_constraints, expr_eval, warp_size_, index_type);
  // Every validation predicate recorded during lowering must hold for
  // these arguments.
  for (const auto& entry : compiled_kernel_->kernel()->summary().validations) {
    NVF_CHECK(expr_eval.evaluate(entry.first).as<bool>(), entry.second);
  }
  executor_utils::validateVectorizedTensors(
      compiled_kernel_->kernel(),
      args,
      output_args,
      compileTimeDataCache(),
      expr_eval);
  executor_utils::validateCircularBuffering(
      compiled_kernel_->kernel(), expr_eval);
  executor_utils::validateIndexCasts(
      compiled_kernel_->kernel(), expr_eval, launch_params);
  // Check that a full warp exists in blockDim.x if the kernel contains
  // ElectSync predicate.
  constexpr int64_t warp_size = 32;
  NVF_ERROR(
      !compiled_kernel_->kernel()->summary().has_elect_sync_predicate ||
          launch_params.bdimx() >= warp_size,
      "This cuda kernel contains electSync predicate. "
      "Expected blockDim.x >= 32 but found ",
      launch_params.bdimx());
  NVF_ERROR_LE(
      std::ssize(compiled_kernel_->kernel()->inputs()),
      args.size(),
      "`args` may contain more entries than regular inputs, e.g., the stream "
      "index.");
  // Record size/stride info for each tensor input so later launches
  // can serialize arguments without re-deriving shapes.
  std::vector<GlobalBufferInfo> input_info;
  input_info.reserve(compiled_kernel_->kernel()->inputs().size());
  for (const auto& [input, arg] :
       zip(compiled_kernel_->kernel()->inputs(), args)) {
    auto* input_tv = dynamic_cast<TensorView*>(input);
    if (input_tv == nullptr) {
      // Scalar inputs carry no buffer info.
      continue;
    }
    auto arg_tensor = arg.as<at::Tensor>();
    std::vector<int64_t> alloc_sizes;
    std::vector<int64_t> alloc_strides;
    if (input_tv->hasAllocation()) {
      // Inputs with an explicit allocation domain must match it.
      std::tie(alloc_sizes, alloc_strides) =
          inferAndValidateAllocationSizesAndStrides(
              arg_tensor, input_tv, expr_eval);
    }
    TensorShapeInfo shape_info;
    shape_info.logical_sizes = arg_tensor.sizes().vec();
    shape_info.logical_strides = arg_tensor.strides().vec();
    if (isSharded(input_tv)) {
      shape_info.unsharded_logical_sizes =
          unshardedSizes(input_tv, shape_info.logical_sizes);
    }
    shape_info.allocation_sizes = alloc_sizes;
    shape_info.allocation_strides = alloc_strides;
    GlobalBufferInfo info{
        input_tv,
        shape_info,
        data_type_to_aten(input_tv->dtype()),
        false,
        false,
        false};
    input_info.emplace_back(info);
  }
  std::vector<GlobalBufferInfo> output_info;
  if (output_args.empty()) {
    // No pre-allocated outputs: derive buffer info from the kernel.
    output_info = getBufferInfos(
        expr_eval, index_type, compiled_kernel_->kernel()->outputs());
  } else {
    // Need to save the information necessary for allocations as
    // future uses of this KernelExecutorEntry may not be provided with
    // allocated outputs
    for (auto output_idx : arange(output_args.size())) {
      NVF_ERROR(
          output_args[output_idx].hasValue() &&
              output_args[output_idx].is<at::Tensor>(),
          "Output is not populated or not a Tensor");
      const auto& output_tensor = output_args[output_idx].as<at::Tensor>();
      GlobalBufferInfo info;
      info.type = output_tensor.scalar_type();
      auto out_val = compiled_kernel_->kernel()->outputs()[output_idx];
      NVF_ERROR(out_val->isA<TensorView>(), "Output is not a TensorView");
      info.tv = out_val->as<TensorView>();
      if (info.tv->hasAllocation()) {
        // Validate that the pre-allocated output tensor matches the allocation
        // domain requirements
        auto [alloc_sizes, alloc_strides] =
            inferAndValidateAllocationSizesAndStrides(
                output_tensor, info.tv, expr_eval);
        info.shape_info.allocation_sizes = alloc_sizes;
        info.shape_info.allocation_strides = alloc_strides;
      }
      info.shape_info.logical_sizes = output_tensor.sizes().vec();
      info.shape_info.logical_strides = output_tensor.strides().vec();
      output_info.emplace_back(info);
    }
  }
  auto intermediates = getIntermediateBufferInfo(expr_eval, index_type);
  // All information is gathered. Save it to KernelExecutorEntry
  executor_entry.launch_params = launch_params;
  executor_entry.outputs = output_info;
  executor_entry.output_aliased_to_input =
      executor_utils::getOutputAliasToInputMap(compiled_kernel_->kernel());
  executor_entry.intermediates = intermediates;
  executor_entry.inputs = input_info;
  executor_entry.init = true;
}
namespace {
// Maps a flat buffer index onto the entry's buffer-info lists in the
// order inputs, then outputs, then intermediates. Errors out with a
// descriptive message on an out-of-range index.
//
// Fix: removed a stray semicolon after the function's closing brace
// (an empty declaration at namespace scope).
GlobalBufferInfo& linear_buffer_info_getter(
    KernelExecutorEntry& entry,
    size_t idx) {
  if (idx < entry.inputs.size()) {
    return entry.inputs[idx];
  } else if (idx < entry.inputs.size() + entry.outputs.size()) {
    return entry.outputs[idx - entry.inputs.size()];
  } else if (
      idx <
      entry.inputs.size() + entry.outputs.size() + entry.intermediates.size()) {
    return entry
        .intermediates[idx - entry.inputs.size() - entry.outputs.size()];
  } else {
    NVF_CHECK(
        0,
        "Invalid buffer index: ",
        idx,
        " input size: ",
        entry.inputs.size(),
        " output size: ",
        entry.outputs.size(),
        " intermediate size: ",
        entry.intermediates.size());
  }
}
} // namespace
// Serializes `args` into the raw byte buffers (`entry.args`) and the
// pointer table (`entry.arg_ptrs`) handed to the CUDA kernel launch.
// CUDA tensors are encoded using the sizes/strides recorded in the
// executor entry; everything else goes through
// polymorphicValueToBytes.
//
// Fix: removed a dead loop over the kernel's inputs whose body did
// nothing (it only `continue`d for non-TensorView inputs).
void KernelExecutor::computeArgs(
    KernelExecutorEntry& entry,
    const KernelArgumentHolder& args) const {
  FUSER_PERF_SCOPE("KernelExecutor::computeArgs");
  if (std::ssize(entry.args) != args.size()) {
    entry.args.resize(args.size());
    entry.arg_ptrs.resize(args.size());
  }
  NVF_ERROR_EQ(
      args.size(), std::ssize(compiled_kernel_->kernel()->parameters()));
  const PrimDataType idx_type = compiled_kernel_->kernel()->indexType();
  // Walks the entry's buffer infos (inputs, outputs, intermediates) in
  // lockstep with the tensor arguments.
  int64_t buffer_info_idx = 0;
  for (auto&& [arg_idx, arg] : enumerate(args)) {
    if (arg.is<at::Tensor>() && arg.as<at::Tensor>().is_cuda()) {
      const auto& buffer_info =
          linear_buffer_info_getter(entry, buffer_info_idx++);
      entry.args[arg_idx] = tensorToBytes(
          arg,
          buffer_info.shape_info.logical_sizes,
          // Prefer allocation strides when an allocation domain was
          // recorded; otherwise fall back to logical strides.
          buffer_info.shape_info.allocation_strides.empty()
              ? buffer_info.shape_info.logical_strides
              : buffer_info.shape_info.allocation_strides,
          idx_type,
          getLastDimAdjustment(buffer_info.tv->dtype()),
          buffer_info.shape_info.unsharded_logical_sizes);
      entry.arg_ptrs[arg_idx] = entry.args[arg_idx].data();
    } else {
      if (arg.is<at::Tensor>()) {
        // Non-CUDA (CPU) tensors still occupy a buffer-info slot.
        buffer_info_idx++;
      }
      auto bytes = polymorphicValueToBytes(
          arg,
          compiled_kernel_->kernel()->parameters()[arg_idx]->dtype(),
          idx_type);
      entry.args[arg_idx] = bytes;
      entry.arg_ptrs[arg_idx] = entry.args[arg_idx].data();
    }
  }
}
// Returns the kernel's maximum dynamic shared memory attribute,
// querying the driver once and caching the result.
int64_t KernelExecutor::getAvailableDynamicSmemSize() {
  if (available_dynamic_smem_size_.has_value()) {
    return available_dynamic_smem_size_.value();
  }
  int size = 0;
  NVFUSER_CUDA_SAFE_CALL(cuFuncGetAttribute(
      &size,
      CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
      compiled_kernel_->cudaExecutable()->function));
  available_dynamic_smem_size_ = size;
  return available_dynamic_smem_size_.value();
}
// Returns the kernel's static shared memory usage, querying the driver
// once and caching the result.
int64_t KernelExecutor::getStaticSmemSize() {
  if (static_smem_size_.has_value()) {
    return static_smem_size_.value();
  }
  int size = 0;
  // Is this really a costly operation worth caching?
  NVFUSER_CUDA_SAFE_CALL(cuFuncGetAttribute(
      &size,
      CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES,
      compiled_kernel_->cudaExecutable()->function));
  static_smem_size_ = size;
  return static_smem_size_.value();
}
// TODO: Move to CompiledKernel
// Checks the requested dynamic shared memory size against both the
// scheduler's recorded expectation (when one exists) and the device's
// total shared memory limit.
void KernelExecutor::validateDynamicSmemSize(int64_t dynamic_smem_size) {
  // If specified, check that dynamic smem size matches what the scheduler
  // expects
  const int64_t expected_dynamic_smem_size =
      compiled_kernel_->kernel()->expectedDynamicSmemBytes();
  if (expected_dynamic_smem_size >= 0) {
    NVF_ERROR(
        dynamic_smem_size == expected_dynamic_smem_size,
        "Actual dynamic smem allocation ",
        dynamic_smem_size,
        " does not match expected size ",
        expected_dynamic_smem_size);
  }
  const int64_t static_smem_size = getStaticSmemSize();
  NVF_ERROR(
      static_smem_size + dynamic_smem_size <= device_smem_limit_,
      "The total shared memory allocation is larger than available memory.",
      " Dynamic size: ",
      dynamic_smem_size,
      ". Static size: ",
      static_smem_size,
      ". Required total size: ",
      static_smem_size + dynamic_smem_size,
      ". Device limit size: ",
      device_smem_limit_);
}
// TODO: Move to CompiledKernel
// Raises the kernel's dynamic shared memory attribute to at least
// `dynamic_smem_size` (validating against the device limit first) and
// returns the resulting available size. Requires a compiled kernel.
int64_t KernelExecutor::ensureAvailableDynamicSmemSize(
    int64_t dynamic_smem_size) {
  NVF_ERROR(
      compiled_kernel_->isCompiled(),
      "Cannot set dynamic smem size unless kernel is compiled");
  if (dynamic_smem_size > getAvailableDynamicSmemSize()) {
    // Fail loudly before asking the driver for more than the device
    // can provide.
    validateDynamicSmemSize(dynamic_smem_size);
    NVFUSER_CUDA_SAFE_CALL(cuFuncSetAttribute(
        compiled_kernel_->cudaExecutable()->function,
        CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
        dynamic_smem_size));
    available_dynamic_smem_size_ = dynamic_smem_size;
  }
  return getAvailableDynamicSmemSize();
}
// TODO: Move to CompiledKernel
// Clears the cached shared-memory attributes so they are re-queried
// from the driver on next access.
void KernelExecutor::resetCompiledKernelProperties() {
  available_dynamic_smem_size_.reset();
  static_smem_size_.reset();
}
namespace {
// Builds the full argument list for `kernel` by splicing in
// host-evaluated RNG seed/offset values where the parameter list
// expects them, consuming `args` positionally for every other
// parameter.
KernelArgumentHolder resolveRNGSeed(
    const kir::Kernel* kernel,
    KernelArgumentHolder& args) {
  ExpressionEvaluator expr_eval;
  KernelArgumentHolder resolved_args;
  resolved_args.reserve(args.size());
  int64_t next_arg = 0;
  for (auto param : kernel->parameters()) {
    auto* def = param->definition();
    const bool is_host_rng_param =
        def != nullptr && def->isA<kir::GetRNGSeedAndOffsetFromHost>();
    if (is_host_rng_param) {
      resolved_args.push(expr_eval.evaluate(param));
    } else {
      resolved_args.push(args[next_arg]);
      ++next_arg;
    }
  }
  return resolved_args;
}
} // namespace
// TODO: Reduce bindings to only those necessary to resolve missing params.
// TODO: Check if this could be reused to also resolve dynamic aliases.
KernelArgumentHolder KernelExecutor::resolveTMA(
KernelExecutorEntry& entry,
const KernelArgumentHolder& args) const {
ExpressionEvaluator expr_eval;
int64_t arg_idx = 0;
NVF_ERROR(
entry.inputs.size() == compiled_kernel_->kernel()->inputs().size(),
"Input size mismatch");
for (auto inp_idx : arange(entry.inputs.size())) {
expr_eval.bind(
compiled_kernel_->kernel()->inputs()[inp_idx], args[arg_idx++]);
}
NVF_ERROR(
entry.outputs.size() == compiled_kernel_->kernel()->outputs().size(),
"Output size mismatch");
for (auto out_idx : arange(entry.outputs.size())) {
expr_eval.bind(
compiled_kernel_->kernel()->outputs()[out_idx], args[arg_idx++]);