diff --git a/src/relay/op/tensor/transform.h b/src/relay/op/tensor/transform.h
index 95a83a905908..dbf8537e0dad 100644
--- a/src/relay/op/tensor/transform.h
+++ b/src/relay/op/tensor/transform.h
@@ -101,29 +101,64 @@ bool ConcatenateRel(const Array<Type>& types, int num_inputs, const Attrs& attrs
   }
 
   // Calculate shape
-  std::vector<IndexExpr> oshape(first->shape.begin(), first->shape.end());
-  int data_length = static_cast<int>(tensor_tuple->fields.size());
+  std::vector<IndexExpr> oshape(ndim);
+  const size_t data_length = tensor_tuple->fields.size();
+
+  // Accumulate the concat axis output dim or decide if this is dynamic concat
+  bool is_dynamic_concat = false;
+  std::vector<TensorType> input_tensors;
+  IndexExpr concat_output_dim = first->shape[axis];
+  for (size_t i = 0; i < data_length; ++i) {
+    const auto& e = Downcast<TensorType>(tensor_tuple->fields[i]);
+    input_tensors.push_back(e);
+    if (e->shape[axis].as<AnyNode>()) {
+      is_dynamic_concat = true;
+      concat_output_dim = Any();
+    } else if (i > 0 && !is_dynamic_concat) {
+      // accumulate axis dimension
+      concat_output_dim += e->shape[axis];
+    }
+  }
+
+  oshape[axis] = concat_output_dim;
+
   for (int i = 0; i < ndim; ++i) {
+    if (i == axis) {
+      // The concat axis is already handled above.
+      // The rest of the body sets the output shape for non-concat axes
+      continue;
+    }
     std::vector<IndexExpr> non_any;
-    for (int j = 0; j < data_length; ++j) {
-      const auto& e = Downcast<TensorType>(tensor_tuple->fields[j]);
+    for (size_t j = 0; j < data_length; ++j) {
+      const auto& e = input_tensors[j];
       if (!e->shape[i].as<AnyNode>()) {
         non_any.push_back(e->shape[i]);
-        // accumulate axis dimension
-        if (j > 0 && i == axis && !oshape[i].as<AnyNode>()) {
-          oshape[i] += e->shape[i];
-        }
       }
     }
-    int non_any_size = static_cast<int>(non_any.size());
-    if (non_any_size != data_length) oshape[i] = Any();
-    if (i != axis) {
-      for (int k = 1; k < non_any_size; k++) {
-        if (reporter->AssertEQ(non_any[0], non_any[k])) continue;
-        throw Error(
-            "relay.concatenate requires all tensors have the same shape "
-            "on non-concatenating axes");
-      }
+    size_t non_any_size = non_any.size();
+    for (size_t k = 1; k < non_any_size; k++) {
+      if (reporter->AssertEQ(non_any[0], non_any[k])) continue;
+      throw Error(
+          "relay.concatenate requires all tensors have the same shape "
+          "on non-concatenating axes");
+    }
+
+    if (non_any_size == data_length) {
+      // All static case
+      oshape[i] = non_any[0];
+    } else if (non_any_size > 0 && is_dynamic_concat) {
+      // For non-concat axes, we want to enforce static shape constraint.
+      // However, if the concat axis is static, the output shape would become static while
+      // the input could be partially static/dynamic. To prevent runtime segfaults due to the lack
+      // of runtime input shape checking for such cases, static shape constraint is only enforced
+      // when the output concat axis is dynamic.
+      //
+      // Examples (both concat on the first axis):
+      //  * [(?, 3), (?, ?)] -> (?, 3)
+      //  * [(1, 3), (1, ?)] -> (2, ?)
+      oshape[i] = non_any[0];
+    } else {
+      oshape[i] = Any();
     }
   }
 
diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py
index 9d05631a753a..b75cc5f5e750 100644
--- a/tests/python/relay/test_any.py
+++ b/tests/python/relay/test_any.py
@@ -208,6 +208,27 @@ def test_any_concat():
     ref = np.concatenate(x_np, axis=0)
     check_result(x_np, mod, ref)
 
+    def test_oshape(in_vars, axis, oshape):
+        z = relay.op.concatenate(in_vars, axis=axis)
+        mod = tvm.IRModule()
+        mod["main"] = relay.Function(in_vars, z)
+        typed_mod = relay.transform.InferType()(mod)
+        assert typed_mod["main"].body.checked_type == relay.TensorType(oshape, dtype="float32")
+
+    x = [relay.var("x", shape=(relay.Any(), 3), dtype="float32") for _ in range(3)]
+    x.append(relay.var("x", shape=(relay.Any(), relay.Any()), dtype="float32"))
+
+    test_oshape(x, 0, (relay.Any(), 3))
+    test_oshape(x, 1, (relay.Any(), relay.Any()))
+
+    # [(1, 3), (1, ?)] -> (2, ?)
+    x = [
+        relay.var("x", shape=(1, 3), dtype="float32"),
+        relay.var("x", shape=(1, relay.Any()), dtype="float32"),
+    ]
+    test_oshape(x, 0, (2, relay.Any()))
+    test_oshape(x, 1, (1, relay.Any()))
+
 
 def verify_any_reshape(x_shape, newshape, x_np_shape, out_shape, variable_newshape=False):
     x = relay.var("x", shape=x_shape, dtype="float32")
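For reference, here is a minimal standalone sketch (not part of the patch) of how the new rule surfaces through the public Relay API. It assumes a TVM build with Relay available; the variable names are purely illustrative, and it mirrors the [(1, 3), (1, ?)] -> (2, ?) case from the comment and test above.

import tvm
from tvm import relay

# Concatenate [(1, 3), (1, ?)] on axis 0; with this change the inferred shape is (2, ?)
a = relay.var("a", shape=(1, 3), dtype="float32")
b = relay.var("b", shape=(1, relay.Any()), dtype="float32")
z = relay.op.concatenate([a, b], axis=0)
mod = tvm.IRModule.from_expr(relay.Function([a, b], z))
mod = relay.transform.InferType()(mod)
print(mod["main"].body.checked_type)  # expected: Tensor[(2, ?), float32]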