Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions integration_tests/test_gruntz.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from lpython import S
from sympy import Symbol

def mrv(e: S, x: S) -> tuple[dict[S, S], S]:
if not e.has(x):
empty_dict : dict[S, S] = {}
return empty_dict, x
else:
raise

def test_mrv():
x: S = Symbol("x")
y: S = Symbol("y")
ans: tuple[dict[S, S], S] = mrv(y, x)

test_mrv()
17 changes: 10 additions & 7 deletions src/libasr/pass/pass_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,8 @@ namespace LCompilers {
static inline bool is_aggregate_or_array_type(ASR::expr_t* var) {
return (ASR::is_a<ASR::Struct_t>(*ASRUtils::expr_type(var)) ||
ASRUtils::is_array(ASRUtils::expr_type(var)) ||
ASR::is_a<ASR::SymbolicExpression_t>(*ASRUtils::expr_type(var)));
ASR::is_a<ASR::SymbolicExpression_t>(*ASRUtils::expr_type(var)) ||
ASR::is_a<ASR::Tuple_t>(*ASRUtils::expr_type(var)));
}

template <class Struct>
Expand Down Expand Up @@ -776,7 +777,7 @@ namespace LCompilers {
}

static inline void handle_fn_return_var(Allocator &al, ASR::Function_t *x,
bool (*is_array_or_struct_or_symbolic)(ASR::expr_t*)) {
bool (*is_array_or_struct_or_symbolic_or_tuple)(ASR::expr_t*)) {
if (ASRUtils::get_FunctionType(x)->m_abi == ASR::abiType::BindPython) {
return;
}
Expand All @@ -788,7 +789,7 @@ namespace LCompilers {
* in avoiding deep copies and the destination memory directly gets
* filled inside the function.
*/
if( is_array_or_struct_or_symbolic(x->m_return_var)) {
if( is_array_or_struct_or_symbolic_or_tuple(x->m_return_var)) {
for( auto& s_item: x->m_symtab->get_scope() ) {
ASR::symbol_t* curr_sym = s_item.second;
if( curr_sym->type == ASR::symbolType::Variable ) {
Expand Down Expand Up @@ -824,9 +825,11 @@ namespace LCompilers {
s_func_type->m_return_var_type = nullptr;

Vec<ASR::stmt_t*> func_body;
func_body.reserve(al, x->n_body - 1);
for (size_t i=0; i< x->n_body - 1; i++) {
func_body.push_back(al, x->m_body[i]);
func_body.reserve(al, x->n_body);
for (size_t i=0; i< x->n_body; i++) {
if (!ASR::is_a<ASR::Return_t>(*x->m_body[i])) {
func_body.push_back(al, x->m_body[i]);
}
}
x->m_body = func_body.p;
x->n_body = func_body.n;
Expand All @@ -835,7 +838,7 @@ namespace LCompilers {
for (auto &item : x->m_symtab->get_scope()) {
if (ASR::is_a<ASR::Function_t>(*item.second)) {
handle_fn_return_var(al, ASR::down_cast<ASR::Function_t>(
item.second), is_array_or_struct_or_symbolic);
item.second), is_array_or_struct_or_symbolic_or_tuple);
}
}
}
Expand Down
130 changes: 128 additions & 2 deletions src/libasr/pass/replace_symbolic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,35 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi

ASR::ttype_t* f_signature= xx.m_function_signature;
ASR::FunctionType_t *f_type = ASR::down_cast<ASR::FunctionType_t>(f_signature);
ASR::ttype_t *type1 = ASRUtils::TYPE(ASR::make_CPtr_t(al, xx.base.base.loc));
ASR::ttype_t *CPtr_type = ASRUtils::TYPE(ASR::make_CPtr_t(al, xx.base.base.loc));
for (size_t i = 0; i < f_type->n_arg_types; ++i) {
if (f_type->m_arg_types[i]->type == ASR::ttypeType::SymbolicExpression) {
f_type->m_arg_types[i] = type1;
f_type->m_arg_types[i] = CPtr_type;
} else if (f_type->m_arg_types[i]->type == ASR::ttypeType::Tuple) {
Vec<ASR::ttype_t*> tuple_type_vec;
ASR::Tuple_t* tuple = ASR::down_cast<ASR::Tuple_t>(f_type->m_arg_types[i]);
tuple_type_vec.reserve(al, tuple->n_type);
for( size_t i = 0; i < tuple->n_type; i++ ) {
if (tuple->m_type[i]->type == ASR::ttypeType::SymbolicExpression) {
tuple_type_vec.push_back(al, CPtr_type);
} else if (tuple->m_type[i]->type == ASR::ttypeType::Dict) {
ASR::Dict_t *dict = ASR::down_cast<ASR::Dict_t>(tuple->m_type[i]);
ASR::ttype_t *key_type = dict->m_key_type;
ASR::ttype_t *value_type = dict->m_value_type;
if (key_type->type == ASR::ttypeType::SymbolicExpression) {
key_type = CPtr_type;
}
if (value_type->type == ASR::ttypeType::SymbolicExpression) {
value_type = CPtr_type;
}
ASR::ttype_t* dict_type = ASRUtils::TYPE(ASR::make_Dict_t(al, xx.base.base.loc, key_type, value_type));
tuple_type_vec.push_back(al, dict_type);
} else {
tuple_type_vec.push_back(al, tuple->m_type[i]);
}
}
ASR::ttype_t* tuple_type = ASRUtils::TYPE(ASR::make_Tuple_t(al, xx.base.base.loc, tuple_type_vec.p, tuple_type_vec.n));
f_type->m_arg_types[i] = tuple_type;
}
}

Expand Down Expand Up @@ -256,6 +281,45 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
ASR::ttype_t* list_type = ASRUtils::TYPE(ASR::make_List_t(al, xx.base.base.loc, CPtr_type));
xx.m_type = list_type;
}
} else if (xx.m_type->type == ASR::ttypeType::Tuple) {
Vec<ASR::ttype_t*> tuple_type_vec;
ASR::Tuple_t* tuple = ASR::down_cast<ASR::Tuple_t>(xx.m_type);
tuple_type_vec.reserve(al, tuple->n_type);
ASR::ttype_t *CPtr_type = ASRUtils::TYPE(ASR::make_CPtr_t(al, xx.base.base.loc));
for( size_t i = 0; i < tuple->n_type; i++ ) {
if (tuple->m_type[i]->type == ASR::ttypeType::SymbolicExpression) {
tuple_type_vec.push_back(al, CPtr_type);
} else if (tuple->m_type[i]->type == ASR::ttypeType::Dict) {
ASR::Dict_t *dict = ASR::down_cast<ASR::Dict_t>(tuple->m_type[i]);
ASR::ttype_t *key_type = dict->m_key_type;
ASR::ttype_t *value_type = dict->m_value_type;
if (key_type->type == ASR::ttypeType::SymbolicExpression) {
key_type = CPtr_type;
}
if (value_type->type == ASR::ttypeType::SymbolicExpression) {
value_type = CPtr_type;
}
ASR::ttype_t* dict_type = ASRUtils::TYPE(ASR::make_Dict_t(al, xx.base.base.loc, key_type, value_type));
tuple_type_vec.push_back(al, dict_type);
} else {
tuple_type_vec.push_back(al, tuple->m_type[i]);
}
}
ASR::ttype_t* tuple_type = ASRUtils::TYPE(ASR::make_Tuple_t(al, xx.base.base.loc, tuple_type_vec.p, tuple_type_vec.n));
xx.m_type = tuple_type;
} else if (xx.m_type->type == ASR::ttypeType::Dict) {
ASR::ttype_t *CPtr_type = ASRUtils::TYPE(ASR::make_CPtr_t(al, xx.base.base.loc));
ASR::Dict_t *dict = ASR::down_cast<ASR::Dict_t>(xx.m_type);
ASR::ttype_t *key_type = dict->m_key_type;
ASR::ttype_t *value_type = dict->m_value_type;
if (key_type->type == ASR::ttypeType::SymbolicExpression) {
key_type = CPtr_type;
}
if (value_type->type == ASR::ttypeType::SymbolicExpression) {
value_type = CPtr_type;
}
ASR::ttype_t* dict_type = ASRUtils::TYPE(ASR::make_Dict_t(al, xx.base.base.loc, key_type, value_type));
xx.m_type = dict_type;
}
}

Expand Down Expand Up @@ -1374,6 +1438,57 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_Assignment_t(al, x.base.base.loc, x.m_target, function_call, nullptr));
pass_result.push_back(al, stmt);
}
} else if (ASR::is_a<ASR::TupleConstant_t>(*x.m_value)) {
ASR::TupleConstant_t* tuple_constant = ASR::down_cast<ASR::TupleConstant_t>(x.m_value);
if (tuple_constant->m_type->type == ASR::ttypeType::Tuple) {
ASR::Tuple_t* tuple = ASR::down_cast<ASR::Tuple_t>(tuple_constant->m_type);
Vec<ASR::ttype_t*> tuple_type_vec;
tuple_type_vec.reserve(al, tuple->n_type);
ASR::ttype_t *CPtr_type = ASRUtils::TYPE(ASR::make_CPtr_t(al, x.base.base.loc));
for( size_t i = 0; i < tuple->n_type; i++ ) {
if (tuple->m_type[i]->type == ASR::ttypeType::SymbolicExpression) {
tuple_type_vec.push_back(al, CPtr_type);
} else if (tuple->m_type[i]->type == ASR::ttypeType::Dict) {
ASR::Dict_t *dict = ASR::down_cast<ASR::Dict_t>(tuple->m_type[i]);
ASR::ttype_t *key_type = dict->m_key_type;
ASR::ttype_t *value_type = dict->m_value_type;
if (key_type->type == ASR::ttypeType::SymbolicExpression) {
key_type = CPtr_type;
}
if (value_type->type == ASR::ttypeType::SymbolicExpression) {
value_type = CPtr_type;
}
ASR::ttype_t* dict_type = ASRUtils::TYPE(ASR::make_Dict_t(al, x.base.base.loc, key_type, value_type));
tuple_type_vec.push_back(al, dict_type);
} else {
tuple_type_vec.push_back(al, tuple->m_type[i]);
}
}
ASR::ttype_t* tuple_type = ASRUtils::TYPE(ASR::make_Tuple_t(al, x.base.base.loc, tuple_type_vec.p, tuple_type_vec.n));
ASR::expr_t* temp_tuple_const = ASRUtils::EXPR(ASR::make_TupleConstant_t(al, x.base.base.loc, tuple_constant->m_elements,
tuple_constant->n_elements, tuple_type));
ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_Assignment_t(al, x.base.base.loc, x.m_target, temp_tuple_const, nullptr));
pass_result.push_back(al, stmt);
}
} else if (ASR::is_a<ASR::DictConstant_t>(*x.m_value)) {
ASR::DictConstant_t* dict_constant = ASR::down_cast<ASR::DictConstant_t>(x.m_value);
if (dict_constant->m_type->type == ASR::ttypeType::Dict) {
ASR::Dict_t* dict = ASR::down_cast<ASR::Dict_t>(dict_constant->m_type);
ASR::ttype_t *CPtr_type = ASRUtils::TYPE(ASR::make_CPtr_t(al, x.base.base.loc));
ASR::ttype_t *key_type = dict->m_key_type;
ASR::ttype_t *value_type = dict->m_value_type;
if (key_type->type == ASR::ttypeType::SymbolicExpression) {
key_type = CPtr_type;
}
if (value_type->type == ASR::ttypeType::SymbolicExpression) {
value_type = CPtr_type;
}
ASR::ttype_t* dict_type = ASRUtils::TYPE(ASR::make_Dict_t(al, x.base.base.loc, key_type, value_type));
ASR::expr_t* temp_dict_const = ASRUtils::EXPR(ASR::make_DictConstant_t(al, x.base.base.loc, dict_constant->m_keys,
dict_constant->n_keys, dict_constant->m_values, dict_constant->n_values, dict_type));
ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_Assignment_t(al, x.base.base.loc, x.m_target, temp_dict_const, nullptr));
pass_result.push_back(al, stmt);
}
}
}

Expand All @@ -1388,6 +1503,17 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
ASR::expr_t* function_call = process_attributes(al, xx.base.base.loc, xx.m_test, module_scope);
xx.m_test = function_call;
}
} else if (ASR::is_a<ASR::LogicalNot_t>(*xx.m_test)) {
ASR::LogicalNot_t* logical_not = ASR::down_cast<ASR::LogicalNot_t>(xx.m_test);
if (ASR::is_a<ASR::IntrinsicScalarFunction_t>(*logical_not->m_arg)) {
ASR::IntrinsicScalarFunction_t* intrinsic_func = ASR::down_cast<ASR::IntrinsicScalarFunction_t>(logical_not->m_arg);
if (intrinsic_func->m_type->type == ASR::ttypeType::Logical) {
ASR::expr_t* function_call = process_attributes(al, xx.base.base.loc, logical_not->m_arg, module_scope);
ASR::expr_t* new_logical_not = ASRUtils::EXPR(ASR::make_LogicalNot_t(al, xx.base.base.loc, function_call,
logical_not->m_type, logical_not->m_value));
xx.m_test = new_logical_not;
}
}
}
}

Expand Down