Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions ASTNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -642,13 +642,15 @@ class ASTWithBlock : public ASTBlock {

class ASTComprehension : public ASTNode {
public:
enum ComprehensionType { LISTCOMP, GENEXPR };
typedef std::list<PycRef<ASTIterBlock>> generator_t;

ASTComprehension(PycRef<ASTNode> result)
: ASTNode(NODE_COMPREHENSION), m_result(std::move(result)) { }
ASTComprehension(PycRef<ASTNode> result, ComprehensionType type = LISTCOMP)
: ASTNode(NODE_COMPREHENSION), m_result(std::move(result)), m_comptype(type) { }

PycRef<ASTNode> result() const { return m_result; }
generator_t generators() const { return m_generators; }
ComprehensionType comprehensionType() const { return m_comptype; }

void addGenerator(PycRef<ASTIterBlock> gen) {
m_generators.emplace_front(std::move(gen));
Expand All @@ -657,6 +659,7 @@ class ASTComprehension : public ASTNode {
private:
PycRef<ASTNode> m_result;
generator_t m_generators;
ComprehensionType m_comptype;

};

Expand Down
141 changes: 124 additions & 17 deletions ASTree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,58 @@ PycRef<ASTNode> BuildFromCode(PycRef<PycCode> code, PycModule* mod)
stack.pop();
}

// Detect inline <genexpr> call and convert to ASTComprehension
if (func.type() == ASTNode::NODE_FUNCTION
&& kwparamList.empty()
&& pparamList.size() == 1) {
PycRef<ASTNode> fun_code_node = func.cast<ASTFunction>()->code();
if (fun_code_node.type() == ASTNode::NODE_OBJECT) {
PycRef<PycObject> fun_code_obj = fun_code_node.cast<ASTObject>()->object();
if (fun_code_obj.type() == PycObject::TYPE_CODE
|| fun_code_obj.type() == PycObject::TYPE_CODE2) {
PycRef<PycCode> genexpr_code = fun_code_obj.cast<PycCode>();
bool is_genexpr = (genexpr_code->flags() & PycCode::CO_GENERATOR)
&& genexpr_code->argCount() == 1
&& genexpr_code->numLocals() >= 1
&& strcmp(genexpr_code->getLocal(0)->value(), ".0") == 0;
if (is_genexpr) {
PycRef<ASTNode> actual_iter = pparamList.front();
PycRef<ASTNode> genexpr_ast = BuildFromCode(genexpr_code, mod);
PycRef<ASTNodeList> genexpr_nodes = genexpr_ast.cast<ASTNodeList>();
if (!genexpr_nodes->nodes().empty()) {
PycRef<ASTNode> last_node = genexpr_nodes->nodes().back();
if (last_node.type() == ASTNode::NODE_RETURN) {
PycRef<ASTNode> ret_val = last_node.cast<ASTReturn>()->value();
if (ret_val.type() == ASTNode::NODE_COMPREHENSION
&& ret_val.cast<ASTComprehension>()->comprehensionType()
== ASTComprehension::GENEXPR) {
PycRef<ASTComprehension> inner = ret_val.cast<ASTComprehension>();
PycRef<ASTComprehension> result = new ASTComprehension(
inner->result(), ASTComprehension::GENEXPR);
// Rebuild generators, replacing the implicit .0 iter
// of the outermost for-clause with the actual argument.
bool outermost = true;
for (const auto& gen : inner->generators()) {
PycRef<ASTNode> gen_iter = outermost ? actual_iter : gen->iter();
PycRef<ASTIterBlock> new_gen = new ASTIterBlock(
gen->blktype(), gen->start(), gen->end(), gen_iter);
new_gen->setIndex(gen->index());
new_gen->setComprehension(gen->isComprehension());
if (gen->condition())
new_gen->setCondition(gen->condition());
result->addGenerator(new_gen);
outermost = false;
}
stack.push(result.cast<ASTNode>());
break;
}
}
}
}
}
}
}

stack.push(new ASTCall(func, pparamList, kwparamList));
}
break;
Expand Down Expand Up @@ -555,24 +607,45 @@ PycRef<ASTNode> BuildFromCode(PycRef<PycCode> code, PycModule* mod)
{
PycRef<ASTNode> kw = stack.top();
stack.pop();
int kwparams = (operand & 0xFF00) >> 8;
int pparams = (operand & 0xFF);
ASTCall::kwparam_t kwparamList;
ASTCall::pparam_t pparamList;
for (int i=0; i<kwparams; i++) {
PycRef<ASTNode> val = stack.top();
stack.pop();
PycRef<ASTNode> key = stack.top();
stack.pop();
kwparamList.push_front(std::make_pair(key, val));
int kwparams, pparams;
if (mod->verCompare(3, 9) >= 0) {
if (kw.type() != ASTNode::NODE_OBJECT) {
fprintf(stderr, "Something TERRIBLE happened!!\n");
break;
}
PycRef<PycObject> kw_obj = kw.cast<ASTObject>()->object();
if (kw_obj.type() != PycObject::TYPE_TUPLE && kw_obj.type() != PycObject::TYPE_SMALL_TUPLE) {
fprintf(stderr, "Something TERRIBLE happened!!\n");
break;
}
const auto& kw_names = kw_obj.cast<PycTuple>()->values();
kwparams = static_cast<int>(kw_names.size());
pparams = operand - kwparams;
for (int i = kwparams - 1; i >= 0; i--) {
PycRef<ASTNode> val = stack.top();
stack.pop();
PycRef<ASTNode> key = new ASTObject(kw_names[i]);
kwparamList.push_front(std::make_pair(key, val));
}
} else {
kwparams = (operand & 0xFF00) >> 8;
pparams = (operand & 0xFF);
for (int i = 0; i < kwparams; i++) {
PycRef<ASTNode> val = stack.top();
stack.pop();
PycRef<ASTNode> key = stack.top();
stack.pop();
kwparamList.push_front(std::make_pair(key, val));
}
}
for (int i=0; i<pparams; i++) {
for (int i = 0; i < pparams; i++) {
pparamList.push_front(stack.top());
stack.pop();
}
PycRef<ASTNode> func = stack.top();
stack.pop();

PycRef<ASTNode> call = new ASTCall(func, pparamList, kwparamList);
call.cast<ASTCall>()->setKW(kw);
stack.push(call);
Expand All @@ -588,14 +661,14 @@ PycRef<ASTNode> BuildFromCode(PycRef<PycCode> code, PycModule* mod)
int pparams = (operand & 0xFF);
ASTCall::kwparam_t kwparamList;
ASTCall::pparam_t pparamList;
for (int i=0; i<kwparams; i++) {
for (int i = 0; i < kwparams; i++) {
PycRef<ASTNode> val = stack.top();
stack.pop();
PycRef<ASTNode> key = stack.top();
stack.pop();
kwparamList.push_front(std::make_pair(key, val));
}
for (int i=0; i<pparams; i++) {
for (int i = 0; i < pparams; i++) {
pparamList.push_front(stack.top());
stack.pop();
}
Expand Down Expand Up @@ -915,7 +988,11 @@ PycRef<ASTNode> BuildFromCode(PycRef<PycCode> code, PycModule* mod)
if (mod->verCompare(3, 10) >= 0)
end *= sizeof(uint16_t); // // BPO-27129
end += pos;
comprehension = strcmp(code->name()->value(), "<listcomp>") == 0;
bool is_genexpr = (code->flags() & PycCode::CO_GENERATOR)
&& code->argCount() >= 1
&& code->numLocals() >= 1
&& strcmp(code->getLocal(0)->value(), ".0") == 0;
comprehension = strcmp(code->name()->value(), "<listcomp>") == 0 || is_genexpr;
} else {
PycRef<ASTBlock> top = blocks.top();
end = top->end(); // block end position from SETUP_LOOP
Expand Down Expand Up @@ -1731,6 +1808,16 @@ PycRef<ASTNode> BuildFromCode(PycRef<PycCode> code, PycModule* mod)
break;
}

// For genexpr: POP_TOP discards the sent value after YIELD_VALUE.
// Keep the ASTComprehension on the stack so JUMP_BACKWARD can find it.
if (curblock->blktype() == ASTBlock::BLK_FOR
&& curblock.cast<ASTIterBlock>()->isComprehension()
&& value.type() == ASTNode::NODE_COMPREHENSION
&& value.cast<ASTComprehension>()->comprehensionType() == ASTComprehension::GENEXPR) {
stack.push(value);
break;
}

curblock->append(value);

if (curblock->blktype() == ASTBlock::BLK_FOR
Expand Down Expand Up @@ -1835,6 +1922,19 @@ PycRef<ASTNode> BuildFromCode(PycRef<PycCode> code, PycModule* mod)
{
PycRef<ASTNode> value = stack.top();
stack.pop();
// For genexpr: "return None" sits above the ASTComprehension on the stack.
// Lift the comprehension as the real return value so BuildFromCode exposes it.
if (!stack.empty() && stack.top() != nullptr
&& stack.top().type() == ASTNode::NODE_COMPREHENSION
&& stack.top().cast<ASTComprehension>()->comprehensionType() == ASTComprehension::GENEXPR) {
bool value_is_none = (value == nullptr)
|| (value.type() == ASTNode::NODE_OBJECT
&& value.cast<ASTObject>()->object()->type() == PycObject::TYPE_NONE);
if (value_is_none) {
value = stack.top();
stack.pop();
}
}
curblock->append(new ASTReturn(value));

if ((curblock->blktype() == ASTBlock::BLK_IF
Expand Down Expand Up @@ -2153,6 +2253,10 @@ PycRef<ASTNode> BuildFromCode(PycRef<PycCode> code, PycModule* mod)
&& !curblock->inited()) {
curblock.cast<ASTWithBlock>()->setExpr(value);
curblock.cast<ASTWithBlock>()->setVar(name);
} else if (stack.top().type() == ASTNode::NODE_IMPORT) {
PycRef<ASTImport> import = stack.top().cast<ASTImport>();

import->add_store(new ASTStore(value, name));
} else if (value.type() == ASTNode::NODE_CHAINSTORE) {
append_to_chain_store(value, name, stack, curblock);
} else {
Expand Down Expand Up @@ -2470,7 +2574,11 @@ PycRef<ASTNode> BuildFromCode(PycRef<PycCode> code, PycModule* mod)
{
PycRef<ASTNode> value = stack.top();
stack.pop();
curblock->append(new ASTReturn(value, ASTReturn::YIELD));
if (curblock->blktype() == ASTBlock::BLK_FOR && curblock.cast<ASTIterBlock>()->isComprehension()) {
stack.push(new ASTComprehension(value, ASTComprehension::GENEXPR));
} else {
curblock->append(new ASTReturn(value, ASTReturn::YIELD));
}
}
break;
case Pyc::SETUP_ANNOTATIONS:
Expand Down Expand Up @@ -2943,8 +3051,7 @@ void print_src(PycRef<ASTNode> node, PycModule* mod, std::ostream& pyc_output)
case ASTNode::NODE_COMPREHENSION:
{
PycRef<ASTComprehension> comp = node.cast<ASTComprehension>();

pyc_output << "[ ";
pyc_output << (comp->comprehensionType() == ASTComprehension::LISTCOMP ? "[ " : "( ");
print_src(comp->result(), mod, pyc_output);

for (const auto& gen : comp->generators()) {
Expand All @@ -2957,7 +3064,7 @@ void print_src(PycRef<ASTNode> node, PycModule* mod, std::ostream& pyc_output)
print_src(gen->condition(), mod, pyc_output);
}
}
pyc_output << " ]";
pyc_output << (comp->comprehensionType() == ASTComprehension::LISTCOMP ? " ]" : " )");
}
break;
case ASTNode::NODE_MAP:
Expand Down