
fixing fusion segmenter dropping allocation domain #1033

Merged
jjsjann123 merged 17 commits into main from segmenter_allocation_domain_patch
Oct 11, 2023

Conversation

@jjsjann123
Collaborator

Fixes #1021

Collaborator

@jacobhinkle jacobhinkle left a comment


Good idea. It seems like this implementation will work for stride order, but not if there are Exprs between alloc and rfactor. In that case we would need a replay.

@jjsjann123
Collaborator Author

Good idea. It seems like this implementation will work for stride order, but not if there are Exprs between alloc and rfactor. In that case we would need a replay.

Good point.
I thought replay could only be applied to either a producer or a consumer. Do we have something like:
new_alloc = replayAs(old_alloc, old_rfactor, new_root)
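For the simple stride-order case, such a replay reduces to reapplying a permutation. A toy sketch of the idea (an assumption for illustration: `replayAs` is hypothetical, plain axis names stand in for `IterDomain*`, and a real replay would also have to reproduce the Split/Merge exprs between the domains):

```python
# Conceptual sketch, not nvFuser code. For the pure stride-order case
# (no Split/Merge exprs between rfactor and allocation), "replaying" the
# old allocation domain onto a new root just reapplies the same axis
# permutation.

def replay_as(old_alloc, old_rfactor, new_root):
    """Apply the old_rfactor -> old_alloc reordering to new_root."""
    assert len(old_rfactor) == len(new_root), "domains must match in rank"
    # Position of each old allocation axis within the old rfactor domain.
    perm = [old_rfactor.index(axis) for axis in old_alloc]
    return [new_root[i] for i in perm]

# Allocation domain stored in a different stride order than rfactor:
old_rfactor = ["i0", "i1", "i2"]
old_alloc = ["i2", "i0", "i1"]
new_root = ["j0", "j1", "j2"]
print(replay_as(old_alloc, old_rfactor, new_root))  # ['j2', 'j0', 'j1']
```

This only models reordering; once actual exprs sit between the rfactor and allocation domains, a full transform replay is needed, which is what the discussion below turns to.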

@jjsjann123
Collaborator Author

!build

@jacobhinkle
Collaborator

Good idea. It seems like this implementation will work for stride order, but not if there are Exprs between alloc and rfactor. In that case we would need a replay.

Good point.
I thought replay could only be applied to either a producer or a consumer. Do we have something like:
new_alloc = replayAs(old_alloc, old_rfactor, new_root)

I am not 100% sure but I was thinking BestEffortReplay could do this.

@xwang233
Collaborator

xwang233 commented Oct 6, 2023

!build


@naoyam
Collaborator

naoyam commented Oct 6, 2023

Good idea. It seems like this implementation will work for stride order, but not if there are Exprs between alloc and rfactor. In that case we would need a replay.

Good point.
I thought replay could only be applied to either a producer or a consumer. Do we have something like:
new_alloc = replayAs(old_alloc, old_rfactor, new_root)

I am not 100% sure but I was thinking BestEffortReplay could do this.

That would be ReplayTransformations, but again do we really need this?

@jjsjann123
Collaborator Author

CI is all green. What's the parse job failure that was reported?!?! @xwang233

@xwang233
Collaborator

xwang233 commented Oct 6, 2023

That's from an old run (with the same commit), since GitHub doesn't allow deleting a commit status. You may ignore that. Sorry for the confusion.

@jjsjann123
Collaborator Author

!build

@jjsjann123
Collaborator Author

errr. nuking it didn't work.

I'm seeing failures on:
NVFuserTest.ResizePadToBroadcastDynamic_CUDA
NVFuserTest.FusionReshapeReductionShmoo_CUDA

and the second test failure actually killed the node. 😠

@jjsjann123
Collaborator Author

lol, the error message gives me a headache.
Looks like it failed to infer launch params.

https://gitlab-master.nvidia.com/dl/pytorch/fuser-gh-mirror/-/jobs/71022350/raw

Exception in thread pool task: val.hasValue() INTERNAL ASSERT FAILED at
...
Exception raised from computeLaunchParams at /opt/pytorch/nvfuser/csrc/executor.cpp:1113 (most recent call first)

@jjsjann123
Collaborator Author

Here's the failing test's fusion from printMath.

Inputs:
  T0_g[ iS0{i0}, iS1{i2}, iS2{i3}, iS3{i4}, iS4{i5} ], float
  T1_g[ iS5{i6}, iS6{i7}, iS7{i8}, iS8{i9}, iS9{i10} ], float
  i11, int64_t
  i12, int64_t
  i13, int64_t
  i14, int64_t
  i15, int64_t
  i16, int64_t
  i17, int64_t
  i18, int64_t
Outputs:
  T3_g[ iS19{i6}, iS20{i7}, iS21{i8}, iS22{i9}, iS23{i10} ], float

%kernel_math {
i20 = (nvfuser_index_t)(i17);
i22 = (nvfuser_index_t)(i18);
i24 = (nvfuser_index_t)(i15);
i26 = (nvfuser_index_t)(i16);
i28 = (nvfuser_index_t)(i13);
i30 = (nvfuser_index_t)(i14);
i32 = (nvfuser_index_t)(i11);
i34 = (nvfuser_index_t)(i12);
T2_g[ iS10{i0}, bS24{( ( i2 + ( (nvfuser_index_t)(i17) ) ) + ( (nvfuser_index_t)(i18) ) )}rf, iS25{( ( i3 + ( (nvfuser_index_t)(i15) ) ) + ( (nvfuser_index_t)(i16) ) )}rf, bS26{( ( i4 + ( (nvfuser_index_t)(i13) ) ) + ( (nvfuser_index_t)(i14) ) )}rf, iS27{( ( i5 + ( (nvfuser_index_t)(i11) ) ) + ( (nvfuser_index_t)(i12) ) )}rf ]
   = pad( T0_g[ iS0{i0}, iS1{i2}, iS2{i3}, iS3{i4}, iS4{i5} ], {0, 0, i20, i22, i24, i26, i28, i30, i32, i34} )
T3_g[ iS19{i6}, iS20{i7}, iS21{i8}, iS22{i9}, iS23{i10} ]
   = T1_g[ iS5{i6}, iS6{i7}, iS7{i8}, iS8{i9}, iS9{i10} ]
   * T2_g[ iS10{i0}, bS24{( ( i2 + ( (nvfuser_index_t)(i17) ) ) + ( (nvfuser_index_t)(i18) ) )}rf, iS25{( ( i3 + ( (nvfuser_index_t)(i15) ) ) + ( (nvfuser_index_t)(i16) ) )}rf, bS26{( ( i4 + ( (nvfuser_index_t)(i13) ) ) + ( (nvfuser_index_t)(i14) ) )}rf, iS27{( ( i5 + ( (nvfuser_index_t)(i11) ) ) + ( (nvfuser_index_t)(i12) ) )}rf ];
}

We are segmenting the pointwise op after the pad into a separate kernel.
This is the output when we keep the rfactor domain on the input (T2_g).

(gdb) print fusion->printMath(true)
Inputs:
  T1_g[ iS101{( ceilDiv(( ceilDiv(( ceilDiv(( i6 * ( i7 * ( i8 * ( i9 * i10 ) ) ) ), 128) ), 1) ), 1) )}, iS102{1}, iS100{1}, iS98{128} ], float
  T2_g[ iS81{( ceilDiv(( ceilDiv(( ceilDiv(( i0 * ( ( ( i2 + ( (nvfuser_index_t)(i17) ) ) + ( (nvfuser_index_t)(i18) ) ) * ( ( ( i3 + ( (nvfuser_index_t)(i15) ) ) + ( (nvfuser_index_t)(i16) ) ) * ( ( ( i4 + ( (nvfuser_index_t)(i13) ) ) + ( (nvfuser_index_t)(i14) ) ) * ( ( i5 + ( (nvfuser_index_t)(i11) ) ) + ( (nvfuser_index_t)(i12) ) ) ) ) ) ), 128) ), 1) ), 1) )}, iS82{1}, iS80{1}, iS78{128} ], float
Outputs:
  T3_g[ iblockIdx.x51{( ceilDiv(( ceilDiv(( ceilDiv(( i6 * ( i7 * ( i8 * ( i9 * i10 ) ) ) ), 128) ), 1) ), 1) )}, iUS52{1}, iS50{1}, ithreadIdx.x48{128} ] ca_pos( 2 ) produce_pos( 4 ), float

%kernel_math {
T4_l[ iblockIdx.x91{( ceilDiv(( ceilDiv(( ceilDiv(( i6 * ( i7 * ( i8 * ( i9 * i10 ) ) ) ), 128) ), 1) ), 1) )}, iUS92{1}, iS90{1}, ithreadIdx.x88{128} ] ca_pos( 2 )
   = Set( T1_g[ iS101{( ceilDiv(( ceilDiv(( ceilDiv(( i6 * ( i7 * ( i8 * ( i9 * i10 ) ) ) ), 128) ), 1) ), 1) )}, iS102{1}, iS100{1}, iS98{128} ], cache_op=Streaming )
T5_l[ iblockIdx.x71{( ceilDiv(( ceilDiv(( ceilDiv(( i0 * ( ( ( i2 + ( (nvfuser_index_t)(i17) ) ) + ( (nvfuser_index_t)(i18) ) ) * ( ( ( i3 + ( (nvfuser_index_t)(i15) ) ) + ( (nvfuser_index_t)(i16) ) ) * ( ( ( i4 + ( (nvfuser_index_t)(i13) ) ) + ( (nvfuser_index_t)(i14) ) ) * ( ( i5 + ( (nvfuser_index_t)(i11) ) ) + ( (nvfuser_index_t)(i12) ) ) ) ) ) ), 128) ), 1) ), 1) )}, iUS72{1}, iS70{1}, ithreadIdx.x68{128} ] ca_pos( 2 )
   = Set( T2_g[ iS81{( ceilDiv(( ceilDiv(( ceilDiv(( i0 * ( ( ( i2 + ( (nvfuser_index_t)(i17) ) ) + ( (nvfuser_index_t)(i18) ) ) * ( ( ( i3 + ( (nvfuser_index_t)(i15) ) ) + ( (nvfuser_index_t)(i16) ) ) * ( ( ( i4 + ( (nvfuser_index_t)(i13) ) ) + ( (nvfuser_index_t)(i14) ) ) * ( ( i5 + ( (nvfuser_index_t)(i11) ) ) + ( (nvfuser_index_t)(i12) ) ) ) ) ) ), 128) ), 1) ), 1) )}, iS82{1}, iS80{1}, iS78{128} ], cache_op=AllLevels )
T6_l[ iblockIdx.x61{( ceilDiv(( ceilDiv(( ceilDiv(( i6 * ( i7 * ( i8 * ( i9 * i10 ) ) ) ), 128) ), 1) ), 1) )}, iUS62{1}, iS60{1}, ithreadIdx.x58{128} ] ca_pos( 4 ) produce_pos( 2 )
   = T4_l[ iblockIdx.x91{( ceilDiv(( ceilDiv(( ceilDiv(( i6 * ( i7 * ( i8 * ( i9 * i10 ) ) ) ), 128) ), 1) ), 1) )}, iUS92{1}, iS90{1}, ithreadIdx.x88{128} ] ca_pos( 2 )
   * T5_l[ iblockIdx.x71{( ceilDiv(( ceilDiv(( ceilDiv(( i0 * ( ( ( i2 + ( (nvfuser_index_t)(i17) ) ) + ( (nvfuser_index_t)(i18) ) ) * ( ( ( i3 + ( (nvfuser_index_t)(i15) ) ) + ( (nvfuser_index_t)(i16) ) ) * ( ( ( i4 + ( (nvfuser_index_t)(i13) ) ) + ( (nvfuser_index_t)(i14) ) ) * ( ( i5 + ( (nvfuser_index_t)(i11) ) ) + ( (nvfuser_index_t)(i12) ) ) ) ) ) ), 128) ), 1) ), 1) )}, iUS72{1}, iS70{1}, ithreadIdx.x68{128} ] ca_pos( 2 );
T3_g[ iblockIdx.x51{( ceilDiv(( ceilDiv(( ceilDiv(( i6 * ( i7 * ( i8 * ( i9 * i10 ) ) ) ), 128) ), 1) ), 1) )}, iUS52{1}, iS50{1}, ithreadIdx.x48{128} ] ca_pos( 2 ) produce_pos( 4 )
   = Set( T6_l[ iblockIdx.x61{( ceilDiv(( ceilDiv(( ceilDiv(( i6 * ( i7 * ( i8 * ( i9 * i10 ) ) ) ), 128) ), 1) ), 1) )}, iUS62{1}, iS60{1}, ithreadIdx.x58{128} ] ca_pos( 4 ) produce_pos( 2 ), cache_op=Streaming )
}

And this is what we have when we clean up the rfactor domain on inputs.

(gdb) print fusion->printMath(true)
Inputs:
  T1_g[ iS115{( ceilDiv(( ceilDiv(( ceilDiv(( i6 * ( i7 * ( i8 * ( i9 * i10 ) ) ) ), 128) ), 1) ), 1) )}, iS116{1}, iS114{1}, iS112{128} ], float
  T2_g[ iS95{( ceilDiv(( ceilDiv(( ceilDiv(( i0 * ( i480 * ( i481 * ( i482 * i483 ) ) ) ), 128) ), 1) ), 1) )}, iS96{1}, iS94{1}, iS92{128} ], float
Outputs:
  T3_g[ iblockIdx.x65{( ceilDiv(( ceilDiv(( ceilDiv(( i6 * ( i7 * ( i8 * ( i9 * i10 ) ) ) ), 128) ), 1) ), 1) )}, iUS66{1}, iS64{1}, ithreadIdx.x62{128} ] ca_pos( 2 ) produce_pos( 4 ), float

%kernel_math {
T4_l[ iblockIdx.x105{( ceilDiv(( ceilDiv(( ceilDiv(( i6 * ( i7 * ( i8 * ( i9 * i10 ) ) ) ), 128) ), 1) ), 1) )}, iUS106{1}, iS104{1}, ithreadIdx.x102{128} ] ca_pos( 2 )
   = Set( T1_g[ iS115{( ceilDiv(( ceilDiv(( ceilDiv(( i6 * ( i7 * ( i8 * ( i9 * i10 ) ) ) ), 128) ), 1) ), 1) )}, iS116{1}, iS114{1}, iS112{128} ], cache_op=Streaming )
T5_l[ iblockIdx.x85{( ceilDiv(( ceilDiv(( ceilDiv(( i0 * ( i480 * ( i481 * ( i482 * i483 ) ) ) ), 128) ), 1) ), 1) )}, iUS86{1}, iS84{1}, ithreadIdx.x82{128} ] ca_pos( 2 )
   = Set( T2_g[ iS95{( ceilDiv(( ceilDiv(( ceilDiv(( i0 * ( i480 * ( i481 * ( i482 * i483 ) ) ) ), 128) ), 1) ), 1) )}, iS96{1}, iS94{1}, iS92{128} ], cache_op=AllLevels )
T6_l[ iblockIdx.x75{( ceilDiv(( ceilDiv(( ceilDiv(( i6 * ( i7 * ( i8 * ( i9 * i10 ) ) ) ), 128) ), 1) ), 1) )}, iUS76{1}, iS74{1}, ithreadIdx.x72{128} ] ca_pos( 4 ) produce_pos( 2 )
   = T4_l[ iblockIdx.x105{( ceilDiv(( ceilDiv(( ceilDiv(( i6 * ( i7 * ( i8 * ( i9 * i10 ) ) ) ), 128) ), 1) ), 1) )}, iUS106{1}, iS104{1}, ithreadIdx.x102{128} ] ca_pos( 2 )
   * T5_l[ iblockIdx.x85{( ceilDiv(( ceilDiv(( ceilDiv(( i0 * ( i480 * ( i481 * ( i482 * i483 ) ) ) ), 128) ), 1) ), 1) )}, iUS86{1}, iS84{1}, ithreadIdx.x82{128} ] ca_pos( 2 );
T3_g[ iblockIdx.x65{( ceilDiv(( ceilDiv(( ceilDiv(( i6 * ( i7 * ( i8 * ( i9 * i10 ) ) ) ), 128) ), 1) ), 1) )}, iUS66{1}, iS64{1}, ithreadIdx.x62{128} ] ca_pos( 2 ) produce_pos( 4 )
   = Set( T6_l[ iblockIdx.x75{( ceilDiv(( ceilDiv(( ceilDiv(( i6 * ( i7 * ( i8 * ( i9 * i10 ) ) ) ), 128) ), 1) ), 1) )}, iUS76{1}, iS74{1}, ithreadIdx.x72{128} ] ca_pos( 4 ) produce_pos( 2 ), cache_op=Streaming )
}

So I compared the before and after on that kernel. Looks like with the rfactor thing in place, the input tensor view seems to be carrying something coming from the previous pad operation?!

@jjsjann123
Collaborator Author

I filed a separate issue to track the suggested removal: #1040

For this PR, I'll revert that and refactor to use the replay suggested by Jacob/Naoya.

@jjsjann123
Collaborator Author

!build

@jacobhinkle
Collaborator

Looks like with the rfactor thing in place, the input tensor view seems to be carrying something coming from the previous pad operation?!

This is a little bit tricky. See #840. That carryover scalar comes from concretization where we currently are modifying extents downstream of dynamic reshapes. Note also #912 which would preserve extents during concretization instead but which also hits this issue on the first segment instead of the second one; that one is relatively easy to fix by just including the output sizes in the inputs() to ViewOp.

For #840 I need to fix the traversal before merging since currently it includes scalars that are already included as extents of input TVs. Once that is done these tests should pass again.

@jjsjann123
Collaborator Author

IIUC, we can just wait until your fixes are in and re-evaluate nuking the function.

Meanwhile, does this PR serve as a reasonable workaround at this time so we can move forward with our allocation domain plumbing on inputs?

@jacobhinkle
Collaborator

IIUC, we can just wait until your fixes are in and re-evaluate nuking the function.

Meanwhile, does this PR serve as a reasonable workaround at this time so we can move forward with our allocation domain plumbing on inputs?

Yes. I think so. We can address #1040 once #840 goes through, or in conjunction. Thanks!

@naoyam
Collaborator

naoyam commented Oct 9, 2023

@jjsjann123 @jacobhinkle Another potential error would be misusing root domains in place of rfactor domains. In the early days we tended to think only about root domains even when rfactor domains should have been used. There may still be some code like that for fusion/segment inputs, as they are (currently) guaranteed not to have rfactor domains.

new_root_domain, tv->domain()->contiguity());

if (tv->domain()->hasAllocation()) {
// we need to replay the new root domain following the old rfactor domain
Collaborator


Could you add tests that actually require this replay? The current one only reorders the domains.

Collaborator Author

@jjsjann123 jjsjann123 Oct 9, 2023


Turns out testing this is tricky, since we don't have convertInputRfactorsToRoots exposed.

So we only hit this function when we go through the automatic scheduler, which triggers segmentation. But the automatic scheduler does not support complex transforms on the allocation domain.

The reason is that setAllocationDomain has a check requiring the allocation domain to be "consistent" with the leaf domain (see nvfuser::ir_utils::validateDomainEquivalence). So in order to set a merge or similar on the allocation domain, I would need to apply the same transformation to the leaf domain.
But we are not supposed to hand a transformed fusion to the automatic scheduler.

Here's the issue when I try to run this example.

  auto tv0 = TensorViewBuilder().ndims(3).shape({-1, -1, -1}).build();
  fusion->addInput(tv0);
  auto s0 = IrBuilder::create<Val>(5, DataType::Float);
  auto tv1 = add(tv0, s0); 
  fusion->addOutput(tv1);

  // tv0->merge(0, 1); // sorry this code shouldn't have been here. It's copied after I started changing the code to manual scheduling.

  std::vector<IterDomain*> alloc_domain = {tv0->axis(2)};
  IterDomain* merged_id = 
      IterDomainBuilder(fusion->zeroVal(), mul(tv0->axis(0)->extent(), tv0->axis(1)->extent()))
          .parallel_type(tv0->axis(0)->getParallelType())
          .iter_type(tv0->axis(0)->getIterType())
          .build();
  IrBuilder::create<Merge>(tv0->axis(0)->container(), merged_id, tv0->axis(0), tv0->axis(1));
  alloc_domain.push_back(merged_id);

  tv0->setAllocationDomain(alloc_domain, {false, true});
C++ exception with description "derived_domain_ == frontier_ INTERNAL ASSERT FAILED at "/opt/pytorch/nvfuser/csrc/ir/utils.cpp":877, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. Invalid derived domain. Initial domain: iS2{i2}, iS3{( i0 * i1 )}. Derived domain: iS0{i0}, iS1{i1}, iS2{i2}
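A simplified model of this consistency check (an assumption about what nvfuser::ir_utils::validateDomainEquivalence verifies, not its actual implementation): replay the recorded transform exprs from one domain and require the resulting frontier to land exactly on the other. The manual Merge above produced merged_id without updating tv0's leaf domain, so a traversal from the merged domain can never reproduce the unmerged axes:

```python
# Toy model of the domain-equivalence check; axis names stand in for
# IterDomain*, and only Merge exprs are modeled (the real check also
# handles Split/Resize and more directions -- this is an assumption for
# illustration only).

def validate_equivalence(initial, derived, merges):
    """Replay merges from `initial`; the frontier must equal `derived`."""
    frontier = set(initial)
    for in_a, in_b, out in merges:
        if in_a in frontier and in_b in frontier:
            frontier -= {in_a, in_b}
            frontier.add(out)
    return frontier == set(derived)

# The manual Merge(i0, i1) -> i01 used for the allocation domain:
merges = [("i0", "i1", "i01")]
# Deriving the merged alloc domain from the root works...
print(validate_equivalence(["i0", "i1", "i2"], ["i01", "i2"], merges))  # True
# ...but starting from the merged domain, the unmerged root axes are
# unreachable, mirroring the derived_domain_ == frontier_ assert above.
print(validate_equivalence(["i01", "i2"], ["i0", "i1", "i2"], merges))  # False
```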

If we really want to test this code path with transformations, I guess we can expose this function directly. But that feels like a bit too much. Tagging @naoyam for a second opinion.

Collaborator

@naoyam naoyam Oct 9, 2023


Is there a typo here?: std::vector<IterDomain*> alloc_domain = {tv0->axis(2)};. Since tv0 is merged to a 2D tensor, tv0->axis(2) should not work, right?

Collaborator Author


Sorry, I copied the wrong code... I was trying to manually schedule it after running into the error.
Then I realized that manual scheduling won't hit the same code path unless I expose the function.

Collaborator Author


If you look at this

  std::vector<IterDomain*> alloc_domain = {tv0->axis(2)};
  IterDomain* merged_id = 
      IterDomainBuilder(fusion->zeroVal(), mul(tv0->axis(0)->extent(), tv0->axis(1)->extent()))
          .parallel_type(tv0->axis(0)->getParallelType())
          .iter_type(tv0->axis(0)->getIterType())
          .build();
  IrBuilder::create<Merge>(tv0->axis(0)->container(), merged_id, tv0->axis(0), tv0->axis(1));
  alloc_domain.push_back(merged_id);

You can see that I was manually merging the two IterDomains.

Collaborator


I'm not entirely sure what you're trying to do here, but the error is because of the manual merge, which doesn't update the leaf domain. The check in setAllocationDomain is to make sure the allocation domain sits between the root and leaf domains, and now the code is trying to set an allocation domain that's outside of the path.

Conceptually, this should be completely fine, since the new allocation domain is equivalent to the leaf domain. However, I don't think the current indexing works with that, as it basically assumes the allocation domain consists of IterDomains that are ancestors of the leaf IterDomains. Fundamentally, it should be possible to traverse from the leaf IterDomains to arbitrary allocation domains, but that's not how it works at the moment.

For this PR, can we just set up a tensor with an allocation domain that sits between root and rfactor domains? Isn't there a test doing that already?

Collaborator Author


Both my attempts at getting the cpp test working failed... I ran into new codegen issues. Opened #1047 and #1048 to track those.

Collaborator


😅

Collaborator


Thanks for trying it out. Looks like there are still many things we need to do in the scheduler, as expected.

In the meantime, please add an assertion that the allocation domain is just a reordered rfactor domain. Or we could expose this function as public so that we can test it without going through the scheduler/segmenter. Either seems fine to me.

Collaborator Author


I'll add an assertion for now, since that's easier to remove later when we fix the other utils in the scheduler.

@jjsjann123 jjsjann123 requested a review from naoyam October 10, 2023 21:36
@jjsjann123
Collaborator Author

rolled everything back to the original commit without the replay refactor.
😢

@jjsjann123
Collaborator Author

!build

@jjsjann123
Collaborator Author

The failing CI is unrelated and patched here: #1058

@jjsjann123
Collaborator Author

!build

Collaborator

@naoyam naoyam left a comment


LGTM, but why are there lots of diff changes? False alarms?

@jjsjann123
Collaborator Author

LGTM, but why are there lots of diff changes? False alarms?

Hmmm. I'm surprised that the code changes here cause a codegen diff... I'll try to take a look at that before merging this one.

@jjsjann123
Collaborator Author

00:05:13 FAILED: CMakeFiles/nvfuser_tests.dir/test/test_gpu_fused_reduction.cpp.o 
00:05:13 /usr/bin/c++ -DUSE_C10D_GLOO -DUSE_C10D_NCCL -DUSE_DISTRIBUTED -DUSE_GTEST -DUSE_RPC -DUSE_TENSORPIPE -I/opt/pytorch/nvfuser/cmake/../third_party/benchmark/include -I/opt/pytorch/nvfuser -I/opt/pytorch/nvfuser/csrc -I/opt/pytorch/nvfuser/lib/dynamic_type/src -I/opt/pytorch/nvfuser/third_party/flatbuffers/include -isystem /opt/pytorch/nvfuser/cmake/../third_party/googletest/googlemock/include -isystem /opt/pytorch/nvfuser/cmake/../third_party/googletest/googletest/include -isystem /opt/pytorch/nvfuser/cmake/../third_party/flatbuffers/include -isystem /opt/pytorch/nvfuser/third_party/googletest/googletest/include -isystem /opt/pytorch/nvfuser/third_party/googletest/googletest -isystem /opt/pytorch/nvfuser/third_party/googletest/googlemock/include -isystem /opt/pytorch/nvfuser/third_party/googletest/googlemock -isystem /usr/local/lib/python3.10/site-packages/torch/include -isystem /usr/local/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -isystem /usr/local/cuda/include -Wno-psabi -D_GLIBCXX_USE_CXX11_ABI=0 -O3 -DNDEBUG -Wall -Wno-unused-function -D_GLIBCXX_USE_CXX11_ABI=0 -Werror -MD -MT CMakeFiles/nvfuser_tests.dir/test/test_gpu_fused_reduction.cpp.o -MF CMakeFiles/nvfuser_tests.dir/test/test_gpu_fused_reduction.cpp.o.d -o CMakeFiles/nvfuser_tests.dir/test/test_gpu_fused_reduction.cpp.o -c /opt/pytorch/nvfuser/test/test_gpu_fused_reduction.cpp
00:05:13 c++: fatal error: Killed signal terminated program cc1plus

OOM on those CI machines?

I'm also seeing some diff coming from https://gitlab-master.nvidia.com/dl/pytorch/fuser-gh-mirror/-/jobs/71309358, where the CI seems unable to find the diff file?!

cc'ing @xwang233

@xwang233
Collaborator

00:05:24 Using 256 jobs for compilation

Probably the build concurrency was too high. I've added a limit on the concurrency. Feel free to restart a build.

@zasdfgbnm
Collaborator

00:05:24 Using 256 jobs for compilation

Probably the build concurrency was too high. I've added a limit on the concurrency. Feel free to restart a build.

Should we increase this value?

Fuser/setup.py

Line 318 in ecc9123

mem_gb_per_task = 3 # Currently compilation of nvFuser souce code takes ~3GB of memory per task, we should adjust this value if it changes in the future.
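The heuristic quoted above can be sketched as a memory-aware job cap (the function name and the example figures are illustrative assumptions; only the ~3 GB-per-task estimate comes from the quoted setup.py line):

```python
import os

def max_compile_jobs(total_mem_gb, mem_gb_per_task=3, cpu_count=None):
    """Cap parallel compile jobs by available memory as well as CPU count,
    so a 256-core builder can't spawn 256 ~3 GB compiler processes and
    OOM-kill cc1plus, as seen in the build log above."""
    if cpu_count is None:
        cpu_count = os.cpu_count() or 1
    return max(1, min(cpu_count, total_mem_gb // mem_gb_per_task))

# e.g. a hypothetical 128 GB node with 256 logical CPUs gets 42 jobs:
print(max_compile_jobs(128, cpu_count=256))  # 42
```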

@jjsjann123
Collaborator Author

I'm gonna ignore the issue then, since it feels low-risk for this PR to impact generated code.
Merging as-is.

@jjsjann123 jjsjann123 merged commit 747a09a into main Oct 11, 2023
@jjsjann123 jjsjann123 deleted the segmenter_allocation_domain_patch branch October 11, 2023 20:02
@jjsjann123 jjsjann123 added the allocation domain issues related to allocation domain support label Oct 18, 2023
Successfully merging this pull request may close these issues:

(maybe) indexing issue gives wrong result