Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ All tests support the same set of arguments :
* `-m,--agg_iters <aggregation count>` number of operations to aggregate together in each iteration. Default : 1.
* `-a,--average <0/1/2/3>` Report performance as an average across all ranks (MPI=1 only). <0=Rank0,1=Avg,2=Min,3=Max>. Default : 1.
* Test operation
* `-s,--setup_file <filename>` Read parameters from file for tests that require it. Currently only required for alltoallv benchmark. Default : disabled. Max of 64 characters for filename.
* `-p,--parallel_init <0/1>` use threads to initialize NCCL in parallel. Default : 0.
* `-c,--check <0/1>` check correctness of results. This can be quite slow on large numbers of GPUs. Default : 1.
* `-z,--blocking <0/1>` Make NCCL collective blocking, i.e. have CPUs wait and sync after each collective. Default : 0.
Expand Down
1 change: 1 addition & 0 deletions paramfiles/alltoallv_paramfiles/Rank1Test1.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
1
2 changes: 2 additions & 0 deletions paramfiles/alltoallv_paramfiles/Rank2Test1.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
0.1,0.4
0,0.5
2 changes: 2 additions & 0 deletions paramfiles/alltoallv_paramfiles/Rank2Test2.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
0,1
0,0
3 changes: 3 additions & 0 deletions paramfiles/alltoallv_paramfiles/Rank3Test1.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
0.3,0.4,0.3
0.2,0,0.8
0.1,0.2,0.7
3 changes: 3 additions & 0 deletions paramfiles/alltoallv_paramfiles/Rank3Test2.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
0,0,1
0,1,0
1,0,0
4 changes: 4 additions & 0 deletions paramfiles/alltoallv_paramfiles/Rank4Test1.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
1,1,1,1
1,1,1,1
1,1,1,1
1,1,1,1
4 changes: 4 additions & 0 deletions paramfiles/alltoallv_paramfiles/Rank4Test2.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
0.25,0.25,0.25,0.25
0.50,0.50,0.50,0.50
0.75,0.75,0.75,0.75
1,1,1,1
4 changes: 4 additions & 0 deletions paramfiles/alltoallv_paramfiles/Rank4Test3.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
1,1,1,1
0.1,0.1,0.1,0.1
0.1,0.1,0.1,0.1
0.1,0.1,0.1,0.1
4 changes: 4 additions & 0 deletions paramfiles/alltoallv_paramfiles/Rank4Test4.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
1,1,1,1
0,0,0,0
0,0,0,0
0,0,0,0
4 changes: 4 additions & 0 deletions paramfiles/alltoallv_paramfiles/Rank4Test5.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
0,0,0,0
1,1,1,1
1,1,1,1
1,1,1,1
5 changes: 5 additions & 0 deletions paramfiles/alltoallv_paramfiles/Rank4Test6.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
1,0,0,0
1,0,0,0
1,0,0,0
1,0,0,0

2 changes: 1 addition & 1 deletion src/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ NVLDFLAGS += $(LIBRARIES:%=-l%)
DST_DIR := $(BUILDDIR)
SRC_FILES := $(wildcard *.cu)
OBJ_FILES := $(SRC_FILES:%.cu=${DST_DIR}/%.o)
BIN_FILES_LIST := all_reduce all_gather broadcast reduce_scatter reduce alltoall scatter gather sendrecv hypercube
BIN_FILES_LIST := all_reduce all_gather broadcast reduce_scatter reduce alltoall alltoallv scatter gather sendrecv hypercube
BIN_FILES := $(BIN_FILES_LIST:%=${DST_DIR}/%_perf)

build: ${BIN_FILES}
Expand Down
2 changes: 1 addition & 1 deletion src/all_gather.cu
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ void AllGatherGetBw(size_t count, int typesize, double sec, double* algBw, doubl
*busBw = baseBw * factor;
}

// Launches a single ncclAllGather on the given communicator/stream.
// The trailing threadArgs* parameter is unused here; it exists so every
// collective shares one runColl signature (alltoallv reads its setup file
// through args). op and root are likewise unused by allgather.
testResult_t AllGatherRunColl(void* sendbuff, void* recvbuff, size_t count, ncclDataType_t type, ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream, struct threadArgs* args) {
NCCLCHECK(ncclAllGather(sendbuff, recvbuff, count, type, comm, stream));
return testSuccess;
}
Expand Down
2 changes: 1 addition & 1 deletion src/all_reduce.cu
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ void AllReduceGetBw(size_t count, int typesize, double sec, double* algBw, doubl
*busBw = baseBw * factor;
}

// Launches a single ncclAllReduce on the given communicator/stream.
// root is unused (allreduce has no root); the trailing threadArgs* parameter
// is unused here and exists only so all collectives share one runColl
// signature (alltoallv needs per-rank setup data from args).
testResult_t AllReduceRunColl(void* sendbuff, void* recvbuff, size_t count, ncclDataType_t type, ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream, struct threadArgs* args) {
NCCLCHECK(ncclAllReduce(sendbuff, recvbuff, count, type, op, comm, stream));
return testSuccess;
}
Expand Down
2 changes: 1 addition & 1 deletion src/alltoall.cu
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ void AlltoAllGetBw(size_t count, int typesize, double sec, double* algBw, double
*busBw = baseBw * factor;
}

testResult_t AlltoAllRunColl(void* sendbuff, void* recvbuff, size_t count, ncclDataType_t type, ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream) {
testResult_t AlltoAllRunColl(void* sendbuff, void* recvbuff, size_t count, ncclDataType_t type, ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream, struct threadArgs* args) {
int nRanks;
NCCLCHECK(ncclCommCount(comm, &nRanks));
size_t rankOffset = count * wordSize(type);
Expand Down
170 changes: 170 additions & 0 deletions src/alltoallv.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
#include <iostream>
#include <fstream>
#include <sstream>
#include <vector>
#include "cuda_runtime.h"
#include "common.h"

int CHECK = 0; // NOTE(review): appears unused within this file — presumably a correctness-check toggle; TODO confirm against common.cu before removing.

/**
 * @brief Parses the alltoallv parameter file into an nranks x nranks matrix of
 *        imbalancing factors.
 *
 * The file is CSV: one row per source rank, one column per destination rank,
 * each value in [0, 1] (a fraction of the per-peer maximum count).
 *
 * @param nranks             Number of ranks; the matrix must be nranks x nranks.
 * @param imbalancingFactors Output: receives the parsed matrix on success.
 * @param filename           Path of the parameter file to parse.
 * @return testSuccess on success; testInternalError if the file cannot be
 *         opened, contains a non-numeric or out-of-range value, or has the
 *         wrong dimensions.
 **/
testResult_t parseParamFile(int nranks, std::vector<std::vector<double>> &imbalancingFactors, const char filename[PATH_MAX]){
  std::vector<std::vector<double>> paramFile_data;
  std::ifstream paramFile(filename);

  if (!paramFile.is_open()) {
    PRINT("\nUNABLE TO OPEN PARAMS FILE AT: %s\n", filename);
    return testInternalError;
  }

  std::string row;
  int rowidx = 0;
  while (std::getline(paramFile, row)) { // iterate over every row
    std::vector<double> values;          // values from this line
    std::stringstream rowstream(row);
    std::string value;
    while (std::getline(rowstream, value, ',')) { // split the row on commas
      double dval;
      try {
        dval = std::stod(value);
      } catch (const std::exception&) {
        // std::stod throws std::invalid_argument / std::out_of_range on bad
        // input; report it as a parse error instead of aborting the test.
        PRINT("\nINVALID PARAMS FILE, NON-NUMERIC ENTRY IN ROW NUMBER: %i\n", rowidx);
        return testInternalError;
      }
      // Each value must be in [0,1]: it scales the per-peer maximum count.
      if (dval < 0 || dval > 1) {
        PRINT("\nINVALID PARAMS FILE, PARAMETER OUT OF 0:1 RANGE, ROW NUMBER: %i \n", rowidx);
        return testInternalError;
      }
      values.push_back(dval);
    }
    // Each row needs exactly one entry per rank (cast avoids a
    // signed/unsigned comparison).
    if (values.size() != (size_t)nranks) {
      PRINT("\nINVALID PARAMS FILE, ROW %i DOES NOT HAVE CORRECT NUMBER OF VALUES, HAS %lu ENTRIES, NEEDS %i ENTRIES\n", rowidx, values.size(), nranks);
      return testInternalError;
    }
    paramFile_data.push_back(values);
    rowidx++;
  }

  // The matrix needs exactly one row per rank. (%lu fixes the old %i, which
  // is undefined behavior for a size_t argument.)
  if (paramFile_data.size() != (size_t)nranks) {
    PRINT("\nINVALID PARAMS FILE, DOES NOT HAVE CORRECT NUMBER OF ROWS, HAS %lu ROWS, NEEDS %i ROWS\n", paramFile_data.size(), nranks);
    return testInternalError;
  }

  imbalancingFactors = paramFile_data; // store the data in the return variable
  return testSuccess;
}
// Computes element counts for alltoallv: the total send/recv counts are
// `count` rounded down to a multiple of nranks, and each rank exchanges at
// most count/nranks elements with each peer. In-place offsets are zero
// (in-place alltoallv is not supported by this test).
void AlltoAllvGetCollByteCount(size_t *sendcount, size_t *recvcount, size_t *paramcount, size_t *sendInplaceOffset, size_t *recvInplaceOffset, size_t count, int nranks) {
  const size_t perPeer = count / nranks; // maximum elements exchanged with a single peer
  *paramcount = perPeer;
  *sendcount  = perPeer * nranks;        // total send count, rounded to a multiple of nranks
  *recvcount  = perPeer * nranks;        // total recv count, rounded to a multiple of nranks
  *sendInplaceOffset = 0;
  *recvInplaceOffset = 0;
}

/**
 * @brief Prepares buffers for an alltoallv run: zeroes each local GPU's
 *        receive buffer, fills its send buffer, and builds the expected
 *        result used for correctness checking.
 * @param args     Per-thread test state (buffers, GPU list, rank geometry,
 *                 setup_file path).
 * @param type     Element datatype.
 * @param op       Unused here (alltoallv performs no reduction).
 * @param root     Unused here (alltoallv has no root).
 * @param rep      Repetition index; mixed into the data seed (33*rep + rank)
 *                 so each repetition uses different data.
 * @param in_place Non-zero for in-place runs; correctness reporting is then
 *                 disabled since in-place alltoallv is unsupported.
 **/
testResult_t AlltoAllvInitData(struct threadArgs* args, ncclDataType_t type, ncclRedOp_t op, int root, int rep, int in_place) {
size_t maxchunk = args->nbytes / wordSize(type); // per-peer element capacity
int nranks = args->nProcs*args->nThreads*args->nGpus;
//parse the param file into the nranks x nranks imbalancing matrix
std::vector<std::vector<double>> imbalancingFactors;
testResult_t parseSuccess = parseParamFile(nranks, imbalancingFactors, args->setup_file);
if(parseSuccess != testSuccess) return parseSuccess;
for (int i=0; i<args->nGpus; i++) {
CUDACHECK(cudaSetDevice(args->gpus[i]));
CUDACHECK(cudaMemset(args->recvbuffs[i], 0, args->expectedBytes)); //zeroes out the receive buffer of each GPU with total size (recvcount*wordSize(type))
CUDACHECK(cudaMemcpy(args->expected[i], args->recvbuffs[i], args->expectedBytes, cudaMemcpyDefault)); //copies the zeroed out receive buffer to the expected buffer
int rank = ((args->proc*args->nThreads + args->thread)*args->nGpus + i); //current rank
void* data = in_place ? args->recvbuffs[i] : args->sendbuffs[i];
TESTCHECK(InitData(data, maxchunk*nranks, 0, type, ncclSum, 33*rep + rank, 1, 0)); //initializes the sendbuffer data for this rank. Should be chunk size * nranks
for (int j=0; j<nranks; j++) {
// Expected data from source rank j lands at offset j*maxchunk in our recv
// buffer; its length is scaled by imbalancingFactors[j][rank], the same
// factor sender j applies in AlltoAllvRunColl, and its contents use sender
// j's seed (33*rep + j) at sender-side offset rank*maxchunk.
size_t partcount_mod = maxchunk * imbalancingFactors[j][rank];
TESTCHECK(InitData((char*)args->expected[i] + j*maxchunk*wordSize(type), partcount_mod, rank*maxchunk, type, ncclSum, 33*rep + j, 1, 0));
}
CUDACHECK(cudaDeviceSynchronize());
}
// We don't support in-place alltoallv, so skip error checks for in-place runs
args->reportErrors = in_place ? 0 : 1;
return testSuccess;
}

// Converts a measured time into algorithm bandwidth and bus bandwidth (GB/s).
// count is the per-peer element count, so count*nranks*typesize is the total
// bytes moved per rank; bus bandwidth discounts the 1/nranks fraction each
// rank logically keeps for itself.
void AlltoAllvGetBw(size_t count, int typesize, double sec, double* algBw, double* busBw, int nranks) {
  const double totalGB = (double)(count * nranks * typesize) / 1.0E9;
  const double base = totalGB / sec;
  *algBw = base;
  *busBw = base * (((double)(nranks - 1)) / ((double)(nranks)));
}

/**
 * @brief Runs one alltoallv iteration: for every peer r, sends
 *        count*factors[myRank][r] elements and receives
 *        count*factors[r][myRank] elements, grouped into a single NCCL call.
 * @param count Per-peer maximum element count (paramcount from
 *              AlltoAllvGetCollByteCount).
 * @param args  Carries setup_file, the imbalancing-factor CSV path.
 * op and root are unused by alltoallv.
 **/
testResult_t AlltoAllvRunColl(void* sendbuff, void* recvbuff, size_t count, ncclDataType_t type, ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream, struct threadArgs* args) {
  int nRanks, myRank;
  NCCLCHECK(ncclCommCount(comm, &nRanks));
  NCCLCHECK(ncclCommUserRank(comm, &myRank));
  // NOTE(review): the parameter file is re-parsed on every launch, adding file
  // I/O to each timed iteration; consider caching the parsed matrix.
  std::vector<std::vector<double>> imbalancingFactors;
  testResult_t parseSuccess = parseParamFile(nRanks, imbalancingFactors, args->setup_file); //parse the param file
  if (parseSuccess != testSuccess) return parseSuccess;
  size_t rankOffset = count * wordSize(type); // byte stride between per-peer regions

#if NCCL_MAJOR < 2 || (NCCL_MAJOR == 2 && NCCL_MINOR < 7)
  // Fixed version guard: the old `NCCL_MAJOR < 2 || NCCL_MINOR < 7` wrongly
  // rejected future majors (e.g. 3.0, where MINOR < 7 is true).
  printf("NCCL 2.7 or later is needed for alltoallv. This test was compiled with %d.%d.\n", NCCL_MAJOR, NCCL_MINOR);
  return testNcclError;
#else
  // Bounds checks use >= (the old `>` was off by one and allowed an
  // out-of-range index). Checked before ncclGroupStart so an error never
  // leaves a group open; row length is loop-invariant, so check it once.
  if ((size_t)myRank >= imbalancingFactors.size()) {
    PRINT("\nmyRank %i is out of range of imbalancingFactors (size %lu)\n", myRank, imbalancingFactors.size());
    return testInternalError;
  }
  if (imbalancingFactors[myRank].size() < (size_t)nRanks) {
    PRINT("\nimbalancingFactors row %i has %lu entries, needs %i\n", myRank, imbalancingFactors[myRank].size(), nRanks);
    return testInternalError;
  }
  NCCLCHECK(ncclGroupStart());
  for (int r = 0; r < nRanks; r++) {
    // Scale the per-peer maximum by the send (row myRank) and recv (row r)
    // factors; each region starts at its peer's rankOffset slot.
    unsigned long send_count_mod = count * imbalancingFactors[myRank][r];
    unsigned long recv_count_mod = count * imbalancingFactors[r][myRank];
    NCCLCHECK(ncclSend(((char*)sendbuff) + r*rankOffset, send_count_mod, type, r, comm, stream));
    NCCLCHECK(ncclRecv(((char*)recvbuff) + r*rankOffset, recv_count_mod, type, r, comm, stream));
  }
  NCCLCHECK(ncclGroupEnd());
  return testSuccess;
#endif
}

// Collective descriptor consumed by the common harness (via args->collTest):
// name, byte-count sizing, data/expected-buffer init, bandwidth conversion,
// and the NCCL launch function.
struct testColl AlltoAllvTest = {
"AlltoAllV",
AlltoAllvGetCollByteCount,
AlltoAllvInitData,
AlltoAllvGetBw,
AlltoAllvRunColl
};

// Reports the send/recv buffer element counts by delegating to
// AlltoAllvGetCollByteCount; the per-peer count and in-place offsets it also
// produces are discarded here.
void AlltoAllvGetBuffSize(size_t *sendcount, size_t *recvcount, size_t count, int nranks) {
  size_t perPeer, sendOffset, recvOffset;
  AlltoAllvGetCollByteCount(sendcount, recvcount, &perPeer, &sendOffset, &recvOffset, count, nranks);
}

// Harness entry point for the alltoallv benchmark: registers the collective
// descriptor, then times it once per requested datatype. A type of -1 means
// "sweep every supported datatype". op/opName are ignored (no reduction), and
// TimeTest is invoked with a dummy op and root of -1.
testResult_t AlltoAllvRunTest(struct threadArgs* args, int root, ncclDataType_t type, const char* typeName, ncclRedOp_t op, const char* opName) {
  args->collTest = &AlltoAllvTest;

  ncclDataType_t* run_types;
  const char** run_typenames;
  int type_count;

  if ((int)type == -1) {
    // No specific type requested: run the whole supported set.
    type_count = test_typenum;
    run_types = test_types;
    run_typenames = test_typenames;
  } else {
    // A single explicit type was requested on the command line.
    type_count = 1;
    run_types = &type;
    run_typenames = &typeName;
  }

  for (int t = 0; t < type_count; t++) {
    TESTCHECK(TimeTest(args, run_types[t], run_typenames[t], (ncclRedOp_t)0, "none", -1));
  }
  return testSuccess;
}

// Engine registration: buffer sizing and the test entry point for this binary.
struct testEngine AlltoAllvEngine = {
AlltoAllvGetBuffSize,
AlltoAllvRunTest
};

// The common harness links against ncclTestEngine; each benchmark binary
// supplies its own engine through this weak alias.
#pragma weak ncclTestEngine=AlltoAllvEngine
2 changes: 1 addition & 1 deletion src/broadcast.cu
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ void BroadcastGetBw(size_t count, int typesize, double sec, double* algBw, doubl
*busBw = baseBw * factor;
}

testResult_t BroadcastRunColl(void* sendbuff, void* recvbuff, size_t count, ncclDataType_t type, ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream) {
testResult_t BroadcastRunColl(void* sendbuff, void* recvbuff, size_t count, ncclDataType_t type, ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream, struct threadArgs* args) {
int rank;
NCCLCHECK(ncclCommUserRank(comm, &rank));
#if NCCL_MAJOR >= 2 && NCCL_MINOR >= 2
Expand Down
12 changes: 9 additions & 3 deletions src/common.cu
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ int is_main_proc = 0;
thread_local int is_main_thread = 0;

// Command line parameter defaults
static char setup_file[PATH_MAX];
static int nThreads = 1;
static int nGpus = 1;
static size_t minBytes = 32*1024*1024;
Expand Down Expand Up @@ -373,11 +374,10 @@ testResult_t startColl(struct threadArgs* args, ncclDataType_t type, ncclRedOp_t
NCCLCHECK(ncclRedOpCreatePreMulSum(&op, &u64, type, ncclScalarHostImmediate, args->comms[i]));
}
#endif

TESTCHECK(args->collTest->runColl(
(void*)(in_place ? recvBuff + args->sendInplaceOffset*rank : sendBuff),
(void*)(in_place ? recvBuff + args->recvInplaceOffset*rank : recvBuff),
count, type, op, root, args->comms[i], args->streams[i]));
count, type, op, root, args->comms[i], args->streams[i], args));

#if NCCL_VERSION_CODE >= NCCL_VERSION(2,11,0)
if(opIndex >= ncclNumOps) {
Expand Down Expand Up @@ -685,6 +685,7 @@ int main(int argc, char* argv[]) {
double parsed;
int longindex;
static struct option longopts[] = {
{"setup_file",optional_argument, 0, 's'},
{"nthreads", required_argument, 0, 't'},
{"ngpus", required_argument, 0, 'g'},
{"minbytes", required_argument, 0, 'b'},
Expand All @@ -711,12 +712,15 @@ int main(int argc, char* argv[]) {

while(1) {
int c;
c = getopt_long(argc, argv, "t:g:b:e:i:f:n:m:w:p:c:o:d:r:z:y:T:hG:C:a:", longopts, &longindex);
c = getopt_long(argc, argv, "s:t:g:b:e:i:f:n:m:w:p:c:o:d:r:z:y:T:hG:C:a:", longopts, &longindex);

if (c == -1)
break;

switch(c) {
case 's':
strcpy(setup_file,optarg);
break;
case 't':
nThreads = strtol(optarg, NULL, 0);
break;
Expand Down Expand Up @@ -983,6 +987,8 @@ testResult_t run() {
memset(threads, 0, sizeof(struct testThread)*nThreads);

for (int t=nThreads-1; t>=0; t--) {
strcpy(threads[t].args.setup_file, setup_file);

threads[t].args.minbytes=minBytes;
threads[t].args.maxbytes=maxBytes;
threads[t].args.stepbytes=stepBytes;
Expand Down
4 changes: 3 additions & 1 deletion src/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ struct testColl {
ncclRedOp_t op, int root, int rep, int in_place);
void (*getBw)(size_t count, int typesize, double sec, double* algBw, double* busBw, int nranks);
testResult_t (*runColl)(void* sendbuff, void* recvbuff, size_t count, ncclDataType_t type,
ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream);
ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream, struct threadArgs* args);
};
extern struct testColl allReduceTest;
extern struct testColl allGatherTest;
Expand All @@ -110,6 +110,8 @@ struct testEngine {
extern struct testEngine ncclTestEngine;

struct threadArgs {
char setup_file[PATH_MAX];

size_t nbytes;
size_t minbytes;
size_t maxbytes;
Expand Down
2 changes: 1 addition & 1 deletion src/gather.cu
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ void GatherGetBw(size_t count, int typesize, double sec, double* algBw, double*
*busBw = baseBw * factor;
}

testResult_t GatherRunColl(void* sendbuff, void* recvbuff, size_t count, ncclDataType_t type, ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream) {
testResult_t GatherRunColl(void* sendbuff, void* recvbuff, size_t count, ncclDataType_t type, ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream, struct threadArgs* args) {
int nRanks;
NCCLCHECK(ncclCommCount(comm, &nRanks));
int rank;
Expand Down
2 changes: 1 addition & 1 deletion src/hypercube.cu
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ void HyperCubeGetBw(size_t count, int typesize, double sec, double* algBw, doubl
*busBw = baseBw * factor;
}

testResult_t HyperCubeRunColl(void* sendbuff, void* recvbuff, size_t count, ncclDataType_t type, ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream) {
testResult_t HyperCubeRunColl(void* sendbuff, void* recvbuff, size_t count, ncclDataType_t type, ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream, struct threadArgs* args) {
char* sbuff = (char*)sendbuff;
char* rbuff = (char*)recvbuff;
int nRanks;
Expand Down
2 changes: 1 addition & 1 deletion src/reduce.cu
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ void ReduceGetBw(size_t count, int typesize, double sec, double* algBw, double*
*busBw = baseBw;
}

// Launches a single ncclReduce to `root` on the given communicator/stream.
// The trailing threadArgs* parameter is unused here; it exists so all
// collectives share one runColl signature (alltoallv reads setup data
// from args).
testResult_t ReduceRunColl(void* sendbuff, void* recvbuff, size_t count, ncclDataType_t type, ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream, struct threadArgs* args) {
NCCLCHECK(ncclReduce(sendbuff, recvbuff, count, type, op, root, comm, stream));
return testSuccess;
}
Expand Down
2 changes: 1 addition & 1 deletion src/reduce_scatter.cu
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ void ReduceScatterGetBw(size_t count, int typesize, double sec, double* algBw, d
*busBw = baseBw * factor;
}

// Launches a single ncclReduceScatter on the given communicator/stream.
// root is unused (reduce-scatter has no root); the trailing threadArgs*
// parameter is unused here and exists only so all collectives share one
// runColl signature (alltoallv reads setup data from args).
testResult_t ReduceScatterRunColl(void* sendbuff, void* recvbuff, size_t count, ncclDataType_t type, ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream, struct threadArgs* args) {
NCCLCHECK(ncclReduceScatter(sendbuff, recvbuff, count, type, op, comm, stream));
return testSuccess;
}
Expand Down
2 changes: 1 addition & 1 deletion src/scatter.cu
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ void ScatterGetBw(size_t count, int typesize, double sec, double* algBw, double*
*busBw = baseBw * factor;
}

testResult_t ScatterRunColl(void* sendbuff, void* recvbuff, size_t count, ncclDataType_t type, ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream) {
testResult_t ScatterRunColl(void* sendbuff, void* recvbuff, size_t count, ncclDataType_t type, ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream, struct threadArgs* args) {
int nRanks;
NCCLCHECK(ncclCommCount(comm, &nRanks));
int rank;
Expand Down
2 changes: 1 addition & 1 deletion src/sendrecv.cu
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ void SendRecvGetBw(size_t count, int typesize, double sec, double* algBw, double
*busBw = baseBw * factor;
}

testResult_t SendRecvRunColl(void* sendbuff, void* recvbuff, size_t count, ncclDataType_t type, ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream) {
testResult_t SendRecvRunColl(void* sendbuff, void* recvbuff, size_t count, ncclDataType_t type, ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream, struct threadArgs* args) {
int nRanks;
NCCLCHECK(ncclCommCount(comm, &nRanks));
int rank;
Expand Down