diff --git a/source/op/descrpt_se_a_multi_device.cc b/source/op/descrpt_se_a_multi_device.cc index 141b2d89bc..40f2c92eb0 100644 --- a/source/op/descrpt_se_a_multi_device.cc +++ b/source/op/descrpt_se_a_multi_device.cc @@ -237,32 +237,23 @@ class DescrptSeAOp : public OpKernel { int *mesh_host = new int[size], *ilist_host = NULL, *jrange_host = NULL, *jlist_host = NULL; cudaErrcheck(cudaMemcpy(mesh_host, mesh, sizeof(int) * size, cudaMemcpyDeviceToHost)); memcpy (&ilist_host, 4 + mesh_host, sizeof(int *)); - memcpy (&jrange_host, 8 + mesh_host, sizeof(int *)); - memcpy (&jlist_host, 12 + mesh_host, sizeof(int *)); + memcpy (&jrange_host, 8 + mesh_host, sizeof(int *)); + memcpy (&jlist_host, 12 + mesh_host, sizeof(int *)); int const ago = mesh_host[0]; - if (!init) { - ilist_size = (int)(mesh_host[1] * 1.2); - jrange_size = (int)(mesh_host[2] * 1.2); - jlist_size = (int)(mesh_host[3] * 1.2); - cudaErrcheck(cudaMalloc((void **)&ilist, sizeof(int) * ilist_size)); - cudaErrcheck(cudaMalloc((void **)&jrange, sizeof(int) * jrange_size)); - cudaErrcheck(cudaMalloc((void **)&jlist, sizeof(int) * jlist_size)); - init = true; - } - if (ago == 0) { + if (!init || ago == 0) { if (ilist_size < mesh_host[1]) { ilist_size = (int)(mesh_host[1] * 1.2); - cudaErrcheck(cudaFree(ilist)); + if (ilist != NULL) {cudaErrcheck(cudaFree(ilist));} cudaErrcheck(cudaMalloc((void **)&ilist, sizeof(int) * ilist_size)); } if (jrange_size < mesh_host[2]) { jrange_size = (int)(mesh_host[2] * 1.2); - cudaErrcheck(cudaFree(jrange)); + if (jrange != NULL) {cudaErrcheck(cudaFree(jrange));} cudaErrcheck(cudaMalloc((void **)&jrange,sizeof(int) * jrange_size)); } if (jlist_size < mesh_host[3]) { jlist_size = (int)(mesh_host[3] * 1.2); - cudaErrcheck(cudaFree(jlist)); + if (jlist != NULL) {cudaErrcheck(cudaFree(jlist));} cudaErrcheck(cudaMalloc((void **)&jlist, sizeof(int) * jlist_size)); } cudaErrcheck(cudaMemcpy(ilist, ilist_host, sizeof(int) * mesh_host[1], cudaMemcpyHostToDevice)); @@ -284,6 +275,7 @@ class DescrptSeAOp : public OpKernel { max_nbor_size = 4096; } } + init = true; delete [] mesh_host; } }; diff --git a/source/op/descrpt_se_r_multi_device.cc b/source/op/descrpt_se_r_multi_device.cc index c355e34f12..81d2603c79 100644 --- a/source/op/descrpt_se_r_multi_device.cc +++ b/source/op/descrpt_se_r_multi_device.cc @@ -226,32 +226,23 @@ class DescrptSeROp : public OpKernel { int *mesh_host = new int[size], *ilist_host = NULL, *jrange_host = NULL, *jlist_host = NULL; cudaErrcheck(cudaMemcpy(mesh_host, mesh, sizeof(int) * size, cudaMemcpyDeviceToHost)); memcpy (&ilist_host, 4 + mesh_host, sizeof(int *)); - memcpy (&jrange_host, 8 + mesh_host, sizeof(int *)); - memcpy (&jlist_host, 12 + mesh_host, sizeof(int *)); + memcpy (&jrange_host, 8 + mesh_host, sizeof(int *)); + memcpy (&jlist_host, 12 + mesh_host, sizeof(int *)); int const ago = mesh_host[0]; - if (!init) { - ilist_size = (int)(mesh_host[1] * 1.2); - jrange_size = (int)(mesh_host[2] * 1.2); - jlist_size = (int)(mesh_host[3] * 1.2); - cudaErrcheck(cudaMalloc((void **)&ilist, sizeof(int) * ilist_size)); - cudaErrcheck(cudaMalloc((void **)&jrange, sizeof(int) * jrange_size)); - cudaErrcheck(cudaMalloc((void **)&jlist, sizeof(int) * jlist_size)); - init = true; - } - if (ago == 0) { + if (!init || ago == 0) { if (ilist_size < mesh_host[1]) { ilist_size = (int)(mesh_host[1] * 1.2); - cudaErrcheck(cudaFree(ilist)); + if (ilist != NULL) {cudaErrcheck(cudaFree(ilist));} cudaErrcheck(cudaMalloc((void **)&ilist, sizeof(int) * ilist_size)); } if (jrange_size < mesh_host[2]) { jrange_size = (int)(mesh_host[2] * 1.2); - cudaErrcheck(cudaFree(jrange)); + if (jrange != NULL) {cudaErrcheck(cudaFree(jrange));} cudaErrcheck(cudaMalloc((void **)&jrange,sizeof(int) * jrange_size)); } if (jlist_size < mesh_host[3]) { jlist_size = (int)(mesh_host[3] * 1.2); - cudaErrcheck(cudaFree(jlist)); + if (jlist != NULL) {cudaErrcheck(cudaFree(jlist));} cudaErrcheck(cudaMalloc((void **)&jlist, sizeof(int) * jlist_size)); } cudaErrcheck(cudaMemcpy(ilist, ilist_host, sizeof(int) * mesh_host[1], cudaMemcpyHostToDevice)); @@ -273,6 +264,7 @@ class DescrptSeROp : public OpKernel { max_nbor_size = 4096; } } + init = true; delete [] mesh_host; }