@@ -1390,6 +1390,22 @@ def convert_one_hot_v2(g, op, block):
     g.add_node(op.output("Out")[0], out)
 
 
+def convert_p_norm(g, op, block):
+    """Operator converter for p_norm."""
+
+    x = g.get_node(op.input("X")[0])
+    axis = op.attr("axis")
+    p = op.attr("porder")
+    keepdim = op.attr("keepdim")
+    p_node = _expr.const(p, dtype="float32")
+    abs_node = _op.abs(x)
+    pow_node = _op.power(abs_node, p_node)
+    reduce_sum = _op.sum(pow_node, axis=[axis], keepdims=keepdim)
+    p_node1 = _expr.const(1.0 / p, dtype="float32")
+    out = _op.power(reduce_sum, p_node1)
+    g.add_node(op.output("Out")[0], out)
+
+
 def convert_padding(g, op, block):
     """Operator converter for padding."""
 
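The p_norm converter decomposes the norm into the op chain abs -> power(p) -> sum over axis -> power(1/p). A minimal NumPy sketch of that same chain (illustration only, with made-up shapes; not part of the patch), checked against numpy.linalg.norm:

import numpy as np

x = np.random.randn(2, 3, 4).astype("float32")
p, axis, keepdim = 3.0, 1, False

# Same op chain the converter emits: abs -> power -> sum -> power(1 / p).
decomposed = np.power(
    np.sum(np.power(np.abs(x), p), axis=axis, keepdims=keepdim), 1.0 / p
)
reference = np.linalg.norm(x, ord=p, axis=axis, keepdims=keepdim)
np.testing.assert_allclose(decomposed, reference, rtol=1e-5)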
@@ -1638,6 +1654,30 @@ def convert_reshape(g, op, block):
     g.add_node(op.output("Out")[0], out)
 
 
+def convert_roi_align(g, op, block):
+    """Operator converter for roi_align."""
+
+    rois = g.get_node(op.input("ROIs")[0])
+    spatial_scale = op.attr("spatial_scale")
+    if op.attr("aligned"):
+        offset = _expr.const(0.5, dtype="float32")
+        roi_offset = _op.divide(offset, _expr.const(spatial_scale, dtype="float32"))
+        rois = _op.subtract(rois, roi_offset)
+    num_rois = infer_shape(rois)[0]
+    zero_node = _expr.const(0, dtype="int32")
+    batch_index = _op.full(zero_node, [num_rois, 1], dtype="float32")
+    rois = _op.concatenate([batch_index, rois], axis=1)
+    out = _op.vision.roi_align(
+        g.get_node(op.input("X")[0]),
+        rois,
+        pooled_size=[op.attr("pooled_height"), op.attr("pooled_width")],
+        spatial_scale=spatial_scale,
+        sample_ratio=op.attr("sampling_ratio"),
+        mode="avg",
+    )
+    g.add_node(op.output("Out")[0], out)
+
+
 def convert_rnn(g, op, block):
     """Operator converter for rnn."""
 
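Two preprocessing steps happen before the Relay op is called: when aligned is set, the boxes are shifted by 0.5 / spatial_scale (a half-pixel offset expressed in input coordinates), and a zero batch-index column is prepended because Relay's vision.roi_align expects ROIs of shape (num_rois, 5) with a leading batch index, while Paddle supplies 4-column boxes. A minimal NumPy sketch of just this preprocessing (made-up box values, illustration only):

import numpy as np

rois = np.array([[4.0, 4.0, 28.0, 28.0]], dtype="float32")  # x1, y1, x2, y2
spatial_scale, aligned = 0.25, True

if aligned:
    # Half-pixel offset in input coordinates, as in the converter above.
    rois = rois - 0.5 / spatial_scale

batch_index = np.zeros((rois.shape[0], 1), dtype="float32")
rois = np.concatenate([batch_index, rois], axis=1)
print(rois)  # [[ 0.  2.  2. 26. 26.]]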
@@ -2156,6 +2196,45 @@ def convert_softmax(g, op, block):
     g.add_node(op.output("Out")[0], out)
 
 
+def convert_softmax_with_cross_entropy(g, op, block):
+    """Operator converter for softmax_with_cross_entropy."""
+
+    logits = g.get_node(op.input("Logits")[0])
+    labels = g.get_node(op.input("Label")[0])
+    ignore_index = op.attr("ignore_index")
+    axis = op.attr("axis")
+    if axis < 0:
+        axis = len(infer_shape(logits)) + axis
+
+    softmax = _op.nn.softmax(logits, axis=axis)
+
+    g.add_node(op.output("Softmax")[0], softmax)
+
+    softmax = _op.log(softmax)
+    soft_label = op.attr("soft_label")
+    if soft_label:
+        loss = _op.sum(-labels * softmax, axis=axis)
+    else:
+        labels_one = _op.one_hot(
+            labels,
+            on_value=_expr.const(1.0, dtype="float32"),
+            off_value=_expr.const(0.0, dtype="float32"),
+            depth=infer_shape(logits)[axis],
+            axis=axis + 1,
+            dtype="float32",
+        )
+        labels_one = _op.squeeze(labels_one, axis=axis)
+        loss = _op.sum(-labels_one * softmax, axis=axis)
+    loss = _op.expand_dims(loss, axis=axis)
+    if ignore_index != -100:  # only when soft_label is False
+        assert not soft_label, "soft_label and ignore_index cannot be set at the same time."
+        ignore_mask = _op.not_equal(labels, _expr.const(ignore_index, dtype="int64"))
+        ignore_mask = _op.cast(ignore_mask, "float32")
+        loss = _op.multiply(loss, ignore_mask)
+
+    g.add_node(op.output("Loss")[0], loss)
+
+
 def convert_softplus(g, op, block):
     """Operator converter for softplus."""
 
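A minimal NumPy sketch of the hard-label path above (simplified to 2-D logits and 1-D integer labels, whereas the converter handles Paddle's (N, 1) labels via one_hot on axis + 1 followed by a squeeze): log-softmax, one-hot over the class axis, a reduction that is re-expanded on the same axis, then the ignore_index mask.

import numpy as np

logits = np.random.randn(4, 5).astype("float32")
labels = np.array([1, 2, -100, 4], dtype="int64")
axis, ignore_index = 1, -100

log_softmax = logits - np.log(np.sum(np.exp(logits), axis=axis, keepdims=True))
# Clip so the ignored label can still index the identity matrix; its row is
# zeroed out by the mask below anyway.
one_hot = np.eye(logits.shape[axis], dtype="float32")[np.clip(labels, 0, None)]
loss = np.expand_dims(-np.sum(one_hot * log_softmax, axis=axis), axis=axis)
loss = loss * (labels != ignore_index).astype("float32")[:, None]
print(loss.squeeze(axis))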
@@ -2549,6 +2628,7 @@ def convert_where_index(g, op, block):
25492628 "norm" : convert_norm ,
25502629 "not_equal" : convert_elementwise_op ,
25512630 "one_hot_v2" : convert_one_hot_v2 ,
2631+ "p_norm" : convert_p_norm ,
25522632 "pad1d" : convert_padding ,
25532633 "pad2d" : convert_padding ,
25542634 "pad3d" : convert_padding ,
@@ -2561,6 +2641,7 @@ def convert_where_index(g, op, block):
25612641 "relu6" : convert_relu6 ,
25622642 "reshape2" : convert_reshape ,
25632643 "round" : convert_unary_op ,
2644+ "roi_align" : convert_roi_align ,
25642645 "reciprocal" : convert_reciprocal ,
25652646 "reduce_all" : convert_reduce ,
25662647 "reduce_any" : convert_reduce ,
@@ -2584,6 +2665,7 @@ def convert_where_index(g, op, block):
25842665 "size" : convert_size ,
25852666 "slice" : convert_slice ,
25862667 "softmax" : convert_softmax ,
2668+ "softmax_with_cross_entropy" : convert_softmax_with_cross_entropy ,
25872669 "softplus" : convert_softplus ,
25882670 "softsign" : convert_softsign ,
25892671 "softshrink" : convert_softshrink ,