diff --git a/source/lib/src/gpu/tabulate.cu b/source/lib/src/gpu/tabulate.cu index 9f924efd9b..09d02bdf2c 100644 --- a/source/lib/src/gpu/tabulate.cu +++ b/source/lib/src/gpu/tabulate.cu @@ -411,6 +411,26 @@ __global__ void tabulate_fusion_se_a_grad_grad_fifth_order_polynomial( res_grad += res_grad * t; } + /* + * `dz_dy` (or `iteratorC`) represents the derivative of the variable `out` + * in the function `tabulate_fusion_se_a_fifth_order_polynomial`. + * + * The expression `em[em_index] * res_grad * dz_xx + dz_dy_dem[em_index] * + * res` utilizes the product rule of derivatives: `(f * g)' = f' * g + f * + * g'`. + * + * This expression can be alternatively expressed as: + * `dz_dy_dem[em_index] * res + em[em_index] * (res_grad * dz_xx)`. + * Note that we can refer to `dz_dy_dem` as `em'`. + * + * Therefore, we can rewrite this expression as: `em' * res + em * res'`, + * where `em'` is the derivative of `em` and `res'` is the derivative of + * `res`. Additionally, `res'` can be further represented as: `res_grad * + * dz_xx`. + * + * If `enable_se_atten` is true, `res` will be `res * t + res`, and `res'` + * will become `(res_grad * t + res_grad) * dz_xx`. + */ for (int kk = 0; kk < MTILE; kk++) { int em_index = block_idx * nnei * MTILE + ii * MTILE + kk; iteratorC[kk * last_layer_size + thread_idx] += diff --git a/source/lib/src/tabulate.cc b/source/lib/src/tabulate.cc index 9b659269e0..1f49cf0daa 100644 --- a/source/lib/src/tabulate.cc +++ b/source/lib/src/tabulate.cc @@ -306,6 +306,27 @@ void deepmd::tabulate_fusion_se_a_grad_grad_cpu(FPTYPE* dz_dy, var += var * t; var_grad += var_grad * t; } + + /* + * `dz_dy` represents the derivative of the variable `out` in the + * function `deepmd::tabulate_fusion_se_a_cpu`. + * + * The expression `var * hh[0] + dz_xx * var_grad * ll[0]` utilizes the + * product rule of derivatives: `(f * g)' = f' * g + f * g'`. + * + * This expression can be alternatively expressed as: + * `hh[0] * var + ll[0] * (dz_xx * var_grad)`. 
+ * Note that `hh[0]` is one element of `dz_dy_dem`, which is `em'`, and + * `ll[0]` is one element of `em`. + * + * Therefore, we can rewrite this expression as: `em' * var + em * + * var'`, where `em'` is the derivative of `em` and `var'` is the + * derivative of `var`. Additionally, `var'` can be further represented + * as: `var_grad * dz_xx`. + * + * If `enable_se_atten` is true, `var` will be `var * t + var`, and + * `var'` will be `(var_grad * t + var_grad) * dz_xx`. + */ if (unloop) { dz_dy[ii * last_layer_size * 4 + 0 * last_layer_size + kk] += (nnei - jj) * (var * hh[0] + dz_xx * var_grad * ll[0]);