diff --git a/deepmd/env.py b/deepmd/env.py index aaa148b357..4868eb8c3b 100644 --- a/deepmd/env.py +++ b/deepmd/env.py @@ -119,6 +119,7 @@ def get_tf_session_config() -> Any: set_tf_default_nthreads() intra, inter = get_tf_default_nthreads() config = tf.ConfigProto( + gpu_options=tf.GPUOptions(allow_growth=True), intra_op_parallelism_threads=intra, inter_op_parallelism_threads=inter ) return config