diff --git a/sqlmesh/utils/jinja.py b/sqlmesh/utils/jinja.py index 59e9f6dd2f..240b183391 100644 --- a/sqlmesh/utils/jinja.py +++ b/sqlmesh/utils/jinja.py @@ -133,6 +133,12 @@ def find_call_names(node: nodes.Node, vars_in_scope: t.Set[str]) -> t.Iterator[C vars_in_scope = vars_in_scope.copy() for child_node in node.iter_child_nodes(): if "target" in child_node.fields: + # For nodes with assignment targets (Assign, AssignBlock, For, Import), + # the target name could shadow a reference in the right hand side. + # So we need to process the RHS before adding the target to scope. + # For example: {% set model = model.path %} should track model.path. + yield from find_call_names(child_node, vars_in_scope) + target = getattr(child_node, "target") if isinstance(target, nodes.Name): vars_in_scope.add(target.name) @@ -149,7 +155,9 @@ def find_call_names(node: nodes.Node, vars_in_scope: t.Set[str]) -> t.Iterator[C name = call_name(child_node) if name[0][0] != "'" and name[0] not in vars_in_scope: yield (name, child_node) - yield from find_call_names(child_node, vars_in_scope) + + if "target" not in child_node.fields: + yield from find_call_names(child_node, vars_in_scope) def extract_call_names( diff --git a/tests/dbt/test_manifest.py b/tests/dbt/test_manifest.py index e2e7bc706c..2ecf8b8980 100644 --- a/tests/dbt/test_manifest.py +++ b/tests/dbt/test_manifest.py @@ -324,3 +324,46 @@ def test_macro_depenency_none_str(): # "None" macro shouldn't raise a KeyError _macro_references(helper._manifest, node) + + +@pytest.mark.xdist_group("dbt_manifest") +def test_macro_assignment_shadowing(create_empty_project): + project_name = "local" + project_path, models_path = create_empty_project(project_name=project_name) + + macros_path = project_path / "macros" + macros_path.mkdir() + + (macros_path / "model_path_macro.sql").write_text(""" +{% macro model_path_macro() %} + {% if execute %} + {% set model = model.path.split('/')[-1].replace('.sql', '') %} + SELECT '{{ model }}' as model_name + {% else %} + SELECT 'placeholder' as placeholder + {% endif %} +{% endmacro %} +""") + + (models_path / "model_using_path_macro.sql").write_text(""" +{{ model_path_macro() }} +""") + + context = DbtContext(project_path) + profile = Profile.load(context) + + helper = ManifestHelper( + project_path, + project_path, + project_name, + profile.target, + model_defaults=ModelDefaultsConfig(start="2020-01-01"), + ) + + macros = helper.macros(project_name) + assert "model_path_macro" in macros + assert "path" in macros["model_path_macro"].dependencies.model_attrs.attrs + + models = helper.models() + assert "model_using_path_macro" in models + assert "path" in models["model_using_path_macro"].dependencies.model_attrs.attrs