Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion sqlmesh/utils/jinja.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,12 @@ def find_call_names(node: nodes.Node, vars_in_scope: t.Set[str]) -> t.Iterator[C
vars_in_scope = vars_in_scope.copy()
for child_node in node.iter_child_nodes():
if "target" in child_node.fields:
# For nodes with assignment targets (Assign, AssignBlock, For, Import),
# the target name could shadow a reference in the right hand side.
# So we need to process the RHS before adding the target to scope.
# For example: {% set model = model.path %} should track model.path.
yield from find_call_names(child_node, vars_in_scope)

target = getattr(child_node, "target")
if isinstance(target, nodes.Name):
vars_in_scope.add(target.name)
Expand All @@ -149,7 +155,9 @@ def find_call_names(node: nodes.Node, vars_in_scope: t.Set[str]) -> t.Iterator[C
name = call_name(child_node)
if name[0][0] != "'" and name[0] not in vars_in_scope:
yield (name, child_node)
yield from find_call_names(child_node, vars_in_scope)

if "target" not in child_node.fields:
yield from find_call_names(child_node, vars_in_scope)


def extract_call_names(
Expand Down
43 changes: 43 additions & 0 deletions tests/dbt/test_manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,3 +324,46 @@ def test_macro_depenency_none_str():

# "None" macro shouldn't raise a KeyError
_macro_references(helper._manifest, node)


@pytest.mark.xdist_group("dbt_manifest")
def test_macro_assignment_shadowing(create_empty_project):
project_name = "local"
project_path, models_path = create_empty_project(project_name=project_name)

macros_path = project_path / "macros"
macros_path.mkdir()

(macros_path / "model_path_macro.sql").write_text("""
{% macro model_path_macro() %}
{% if execute %}
{% set model = model.path.split('/')[-1].replace('.sql', '') %}
SELECT '{{ model }}' as model_name
{% else %}
SELECT 'placeholder' as placeholder
{% endif %}
{% endmacro %}
""")

(models_path / "model_using_path_macro.sql").write_text("""
{{ model_path_macro() }}
""")

context = DbtContext(project_path)
profile = Profile.load(context)

helper = ManifestHelper(
project_path,
project_path,
project_name,
profile.target,
model_defaults=ModelDefaultsConfig(start="2020-01-01"),
)

macros = helper.macros(project_name)
assert "model_path_macro" in macros
assert "path" in macros["model_path_macro"].dependencies.model_attrs.attrs

models = helper.models()
assert "model_using_path_macro" in models
assert "path" in models["model_using_path_macro"].dependencies.model_attrs.attrs