[Prelude][Relay] Add utility function to extract tensor from prelude tensor_t #4291

@wweic

Description

tensor_t in prelude is a type that wraps a tensor of variable rank and variable shape. For example, tensor_array_concat returns an object of type tensor_t.

Most other TF operators, however, expect tensors of fixed rank. For example, sum(tensor_array_concat(tensor_array)) does not work right now, because sum cannot accept the tensor_t returned by tensor_array_concat. We need functions that convert a tensor_t to a fixed-rank tensor:

let get_data_tensor1: tensor_t -> Tensor[(Any), dtype]
let get_data_tensor2: tensor_t -> Tensor[(Any, Any), dtype]
let get_data_tensor3: tensor_t -> Tensor[(Any, Any, Any), dtype]
let get_data_tensor4: tensor_t -> Tensor[(Any, Any, Any, Any), dtype]
let get_data_tensor5: tensor_t -> Tensor[(Any, Any, Any, Any, Any), dtype]
let get_data_tensor6: tensor_t -> Tensor[(Any, Any, Any, Any, Any, Any), dtype]

We need six functions because Relay cannot represent a function with the following type:

let get_data_tensor: tensor_t -> Tensor[(Any), dtype] or Tensor[(Any, Any), dtype] or Tensor[(Any, Any, Any), dtype]

We also need a function get_tensor_rank: tensor_t -> Tensor[(), int32] so we can dispatch to the right get_data_tensor function given the rank of tensor_t.
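
The dispatch described above can be sketched in plain Python. This is only an illustration of the idea, not Relay or prelude code: the `TensorT` class, the `GET_DATA_TENSOR` table, and the helpers `make_get_data_tensor`, `flatten_sum`, and `sum_tensor_t` are hypothetical names, and nested Python lists stand in for tensor data. It shows one extractor per rank, plus a rank query used to pick the right one before applying a fixed-rank operator.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict

MAX_RANK = 6  # the issue proposes get_data_tensor1 .. get_data_tensor6


@dataclass(frozen=True)
class TensorT:
    """Hypothetical stand-in for prelude tensor_t: payload plus a runtime rank tag."""
    rank: int
    data: Any  # nested Python lists standing in for tensor data


def get_tensor_rank(t: TensorT) -> int:
    """Mirrors the proposed get_tensor_rank: tensor_t -> Tensor[(), int32]."""
    return t.rank


def make_get_data_tensor(rank: int) -> Callable[[TensorT], Any]:
    """Build the extractor for one fixed rank, like get_data_tensor<rank>."""
    def get_data(t: TensorT) -> Any:
        if t.rank != rank:
            raise TypeError(f"expected rank {rank}, got rank {t.rank}")
        return t.data
    return get_data


# One extractor per rank, because each would have a different result type.
GET_DATA_TENSOR: Dict[int, Callable[[TensorT], Any]] = {
    r: make_get_data_tensor(r) for r in range(1, MAX_RANK + 1)
}


def flatten_sum(data: Any) -> Any:
    """Sum every scalar in a nested list (stand-in for a fixed-rank sum op)."""
    if isinstance(data, list):
        return sum(flatten_sum(x) for x in data)
    return data


def sum_tensor_t(t: TensorT) -> Any:
    """Dispatch on the runtime rank, unwrap, then apply the fixed-rank op."""
    rank = get_tensor_rank(t)
    unwrap = GET_DATA_TENSOR[rank]
    return flatten_sum(unwrap(t))


vector = TensorT(rank=1, data=[1, 2, 3])
matrix = TensorT(rank=2, data=[[1, 2], [3, 4]])
print(sum_tensor_t(vector), sum_tensor_t(matrix))  # prints "6 10"
```

In Python a single untyped function could return any rank; the per-rank table is kept here only to mirror the constraint that each Relay function must have one fixed result type.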

Now that we are starting to look into models with tensor array ops, we will have to implement these functions.
