Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 61 additions & 3 deletions sqlmesh/dbt/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
extract_call_names,
jinja_call_arg_name,
)
from sqlglot.helper import ensure_list

if t.TYPE_CHECKING:
from dbt.contracts.graph.manifest import Macro, Manifest
Expand Down Expand Up @@ -353,15 +354,17 @@ def _load_tests(self) -> None:
)

test_model = _test_model(node)
node_config = _node_base_config(node)
node_config["name"] = _build_test_name(node, dependencies)

test = TestConfig(
sql=sql,
model_name=test_model,
test_kwargs=node.test_metadata.kwargs if hasattr(node, "test_metadata") else {},
dependencies=dependencies,
**_node_base_config(node),
**node_config,
)
self._tests_per_package[node.package_name][node.name.lower()] = test
self._tests_per_package[node.package_name][node.unique_id] = test
if test_model:
self._tests_by_owner[test_model].append(test)

Expand Down Expand Up @@ -741,7 +744,12 @@ def _test_model(node: ManifestNode) -> t.Optional[str]:
attached_node = getattr(node, "attached_node", None)
if attached_node:
pieces = attached_node.split(".")
return pieces[-1] if pieces[0] in ["model", "seed"] else None
if pieces[0] in ["model", "seed"]:
# versioned models have format "model.package.model_name.v1" (4 parts)
if len(pieces) == 4:
return f"{pieces[2]}_{pieces[3]}"
return pieces[-1]
return None

key_name = getattr(node, "file_key_name", None)
if key_name:
Expand Down Expand Up @@ -798,3 +806,53 @@ def _strip_jinja_materialization_tags(materialization_jinja: str) -> str:
)

return materialization_jinja.strip()


def _build_test_name(node: ManifestNode, dependencies: Dependencies) -> str:
    """
    Build a user-friendly test name that includes the test's model/source, column,
    and args for tests with custom user names. Needed because dbt only generates these
    names for tests that do not specify the "name" field in their YAML definition.

    Name structure
    - Model test: [namespace]_[test name]_[model name]_[column name]__[arg values]
    - Source test: [namespace]_source_[test name]_[source name]_[table name]_[column name]__[arg values]
    """
    # Standalone tests carry no test_metadata; their dbt-assigned name is used as-is.
    if not hasattr(node, "test_metadata"):
        return node.name

    target_model = _test_model(node)

    # When the test is not attached to a model, derive the entity from the
    # first source dependency ("source.table" -> "source_table").
    target_source = None
    if not target_model and dependencies.sources:
        parts = next(iter(dependencies.sources)).split(".")
        target_source = "_".join(parts) if len(parts) == 2 else parts[-1]

    entity = target_model or target_source or ""
    entity_suffix = f"_{entity}" if entity else ""

    is_source_test = bool(target_source) and not target_model

    prefix = ""
    namespace = getattr(node.test_metadata, "namespace", None)
    if namespace:
        prefix = f"{namespace}_"
    if is_source_test:
        prefix += "source_"

    # Sanitize each kwarg value (dicts contribute their values) into
    # identifier-safe fragments, sorted by kwarg name for determinism.
    sanitized = []
    for key, value in sorted(node.test_metadata.kwargs.items()):
        if key == "model":
            continue
        if isinstance(value, dict):
            value = list(value.values())
        sanitized.extend(re.sub("[^0-9a-zA-Z_]+", "_", str(item)) for item in ensure_list(value))

    joined_args = "__".join(sanitized)
    args_suffix = f"_{joined_args}" if joined_args else ""

    generated = f"{prefix}{node.test_metadata.name}{entity_suffix}{args_suffix}"

    # dbt auto-generated this name, so it is already unique and descriptive.
    if node.name == generated:
        return node.name

    # Custom-named test: decorate the user's chosen name with the entity and
    # arg values so that reusing the same name across targets stays unique.
    custom_prefix = prefix if is_source_test else ""
    return f"{custom_prefix}{node.name}{entity_suffix}{args_suffix}"
9 changes: 8 additions & 1 deletion sqlmesh/dbt/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,14 @@ def is_standalone(self) -> bool:
return True

# Check if test has references to other models
other_refs = {ref for ref in self.dependencies.refs if ref != self.model_name}
# For versioned models, refs include version (e.g., "model_name_v1") but model_name may not
self_refs = {self.model_name}
for ref in self.dependencies.refs:
# versioned models end in _vX
if ref.startswith(f"{self.model_name}_v"):
self_refs.add(ref)

other_refs = {ref for ref in self.dependencies.refs if ref not in self_refs}
return bool(other_refs)

@property
Expand Down
132 changes: 132 additions & 0 deletions tests/dbt/test_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
from pathlib import Path

import pytest

from sqlmesh.dbt.test import TestConfig


Expand All @@ -8,3 +12,131 @@ def test_multiline_test_kwarg() -> None:
test_kwargs={"test_field": "foo\nbar\n"},
)
assert test._kwargs() == 'test_field="foo\nbar"'


@pytest.mark.xdist_group("dbt_manifest")
def test_tests_get_unique_names(tmp_path: Path, create_empty_project) -> None:
    """Audits loaded from dbt tests must all receive distinct, descriptive names."""
    from sqlmesh.utils.yaml import YAML
    from sqlmesh.core.context import Context

    project_dir, model_dir = create_empty_project(project_name="local")

    # Model under test plus the two files backing a dbt versioned model.
    (model_dir / "my_model.sql").write_text(
        "SELECT 1 as id, 'value1' as status", encoding="utf-8"
    )
    versioned_sql = "SELECT 1 as id, 'low' as amount"
    (model_dir / "versioned_model_v1.sql").write_text(versioned_sql, encoding="utf-8")
    (model_dir / "versioned_model_v2.sql").write_text(versioned_sql, encoding="utf-8")

    # Schema exercises the naming collisions we care about:
    # 1. Same test on model and source, both with/without custom test name
    # 2. Same test on same model with different args, both with/without custom test name
    # 3. Versioned model with tests (both built-in and custom named)
    schema = {
        "version": 2,
        "sources": [
            {
                "name": "raw",
                "tables": [
                    {
                        "name": "my_source",
                        "columns": [
                            {
                                "name": "id",
                                "data_tests": [
                                    {"not_null": {"name": "custom_notnull_name"}},
                                    {"not_null": {}},
                                ],
                            }
                        ],
                    }
                ],
            }
        ],
        "models": [
            {
                "name": "my_model",
                "columns": [
                    {
                        "name": "id",
                        "data_tests": [
                            {"not_null": {"name": "custom_notnull_name"}},
                            {"not_null": {}},
                        ],
                    },
                    {
                        "name": "status",
                        "data_tests": [
                            {"accepted_values": {"values": ["value1", "value2"]}},
                            {"accepted_values": {"values": ["value1", "value2", "value3"]}},
                            {
                                "accepted_values": {
                                    "name": "custom_accepted_values_name",
                                    "values": ["value1", "value2"],
                                }
                            },
                            {
                                "accepted_values": {
                                    "name": "custom_accepted_values_name",
                                    "values": ["value1", "value2", "value3"],
                                }
                            },
                        ],
                    },
                ],
            },
            {
                "name": "versioned_model",
                "columns": [
                    {
                        "name": "id",
                        "data_tests": [
                            {"not_null": {}},
                            {"not_null": {"name": "custom_versioned_notnull"}},
                        ],
                    },
                    {
                        "name": "amount",
                        "data_tests": [
                            {"accepted_values": {"values": ["low", "high"]}},
                        ],
                    },
                ],
                "versions": [
                    {"v": 1},
                    {"v": 2},
                ],
            },
        ],
    }

    with open(model_dir / "schema.yml", "w", encoding="utf-8") as schema_file:
        YAML().dump(schema, schema_file)

    context = Context(paths=project_dir)

    # Every audit (regular and standalone) must have a unique, fully-qualified name.
    audit_names = sorted([*context._audits, *context._standalone_audits])
    assert audit_names == [
        "local.accepted_values_my_model_status__value1__value2",
        "local.accepted_values_my_model_status__value1__value2__value3",
        "local.accepted_values_versioned_model_v1_amount__low__high",
        "local.accepted_values_versioned_model_v2_amount__low__high",
        "local.custom_accepted_values_name_my_model_status__value1__value2",
        "local.custom_accepted_values_name_my_model_status__value1__value2__value3",
        "local.custom_notnull_name_my_model_id",
        "local.custom_versioned_notnull_versioned_model_v1_id",
        "local.custom_versioned_notnull_versioned_model_v2_id",
        "local.not_null_my_model_id",
        "local.not_null_versioned_model_v1_id",
        "local.not_null_versioned_model_v2_id",
        "local.source_custom_notnull_name_raw_my_source_id",
        "local.source_not_null_raw_my_source_id",
    ]