From c27df8b74b84a5bb05a3c81d4c6a471f00d459cb Mon Sep 17 00:00:00 2001 From: Duo Date: Tue, 11 May 2021 17:22:49 +0800 Subject: [PATCH] Fix bug in tf allocate_temp --- source/op/prod_env_mat_multi_device.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/source/op/prod_env_mat_multi_device.cc b/source/op/prod_env_mat_multi_device.cc index be812dbc9e..1bdd5553f4 100644 --- a/source/op/prod_env_mat_multi_device.cc +++ b/source/op/prod_env_mat_multi_device.cc @@ -803,7 +803,7 @@ _norm_copy_coord_gpu( // Tensor FPTYPE_temp; TensorShape FPTYPE_shape; FPTYPE_shape.AddDim(nall*3); - context->allocate_temp(DT_DOUBLE, FPTYPE_shape, tensor_list); + context->allocate_temp(DataTypeToEnum::value, FPTYPE_shape, tensor_list); FPTYPE * tmp_coord = (*tensor_list).flat().data(); cudaErrcheck(cudaMemcpy(tmp_coord, coord, sizeof(FPTYPE) * nall * 3, cudaMemcpyDeviceToDevice)); @@ -819,7 +819,7 @@ _norm_copy_coord_gpu( //Tensor double_temp; TensorShape double_shape; double_shape.AddDim(18); - context->allocate_temp(DT_DOUBLE, double_shape, tensor_list+1); + context->allocate_temp(DataTypeToEnum::value, double_shape, tensor_list+1); //Tensor int_temp; TensorShape int_shape; int_shape.AddDim(23+nloc*3+loc_cellnum+total_cellnum*3+total_cellnum*3+loc_cellnum+1+total_cellnum+1+nloc); @@ -840,7 +840,7 @@ _norm_copy_coord_gpu( //Tensor cpy_temp; TensorShape cpy_shape; cpy_shape.AddDim(mem_cpy*3); - context->allocate_temp(DT_DOUBLE, cpy_shape, tensor_list+3); + context->allocate_temp(DataTypeToEnum::value, cpy_shape, tensor_list+3); //Tensor t_temp; TensorShape t_shape; t_shape.AddDim(mem_cpy*2);