@@ -196,6 +196,7 @@ RnnPrimitive GetRnnFwdPrim(const DNNLRnnLayerParam& layer_param,
196196 auto src_state_desc = memory::desc (layer_param.state_dims , data_type, tag::ldnc);
197197 auto src_cell_desc = memory::desc (layer_param.cell_dims , data_type, tag::ldnc);
198198 auto weight_peep_desc = memory::desc ();
199+ // clang-format off
199200 auto weight_proj_desc = layer_param.proj_size > 0 ?
200201 memory::desc (layer_param.weight_proj_dims , weight_type, tag::any) :
201202 memory::desc ();
@@ -205,6 +206,7 @@ RnnPrimitive GetRnnFwdPrim(const DNNLRnnLayerParam& layer_param,
205206 auto dst_cell_desc = layer_param.state_outputs ?
206207 memory::desc (layer_param.cell_dims , data_type, tag::ldnc) :
207208 memory::desc ();
209+ // clang-format on
208210
209211 auto fwd = RnnPrimitive ();
210212 switch (mode) {
@@ -265,19 +267,23 @@ RnnBwdPrimitive GetRnnBwdPrim(const DNNLRnnForwardTraining& fwd,
265267 memory::data_type data_type = get_dnnl_type (data.dtype ());
266268 memory::data_type weight_type = get_dnnl_type (params.dtype ());
267269 const prop_kind prop = prop_kind::backward;
270+ // clang-format off
268271 rnn_direction dnnl_rnn_direction = layer_param.bidirectional ?
269272 rnn_direction::bidirectional_concat :
270273 rnn_direction::unidirectional;
274+ // clang-format on
271275
272276 auto src_layer_desc = memory::desc (layer_param.src_dims , data_type, tag::tnc);
273277 auto weight_layer_desc = memory::desc (layer_param.weight_layer_dims , weight_type, tag::any);
274278 auto weight_iter_desc = memory::desc (layer_param.weight_iter_dims , weight_type, tag::any);
275279 auto bias_desc = memory::desc (layer_param.bias_dims , data_type, tag::ldgo);
276280 auto dst_layer_desc = memory::desc (layer_param.dst_dims , data_type, tag::tnc);
277281 auto src_state_desc = memory::desc (layer_param.state_dims , data_type, tag::ldnc);
282+ // clang-format off
278283 auto dst_state_desc = layer_param.state_outputs ?
279284 memory::desc (layer_param.state_dims , data_type, tag::ldnc) :
280285 memory::desc ();
286+ // clang-format on
281287
282288 const void * fwd_pd = fwd.GetPrimDesc ();
283289 auto bwd = RnnBwdPrimitive ();
@@ -1126,9 +1132,11 @@ void DNNLRnnOp::Forward(const OpContext& ctx,
11261132 const int seq_length = default_param.seq_length_ ;
11271133 const int batch_size = default_param.batch_size_ ;
11281134 const int state_size = default_param.state_size ;
1135+ // clang-format off
11291136 const int iter_size = default_param.projection_size .has_value () ?
11301137 default_param.projection_size .value () :
11311138 default_param.state_size ;
1139+ // clang-format on
11321140 const int directions = default_param.bidirectional ? 2 : 1 ;
11331141 dnnl::memory::desc dst_desc ({seq_length, batch_size, directions * iter_size},
11341142 get_dnnl_type (data_dtype),
0 commit comments