diff --git a/sqlmesh/core/config/loader.py b/sqlmesh/core/config/loader.py index 75915800e6..2d202cb276 100644 --- a/sqlmesh/core/config/loader.py +++ b/sqlmesh/core/config/loader.py @@ -177,6 +177,7 @@ def load_config_from_paths( dbt_profile_name=kwargs.pop("profile", None), dbt_target_name=kwargs.pop("target", None), variables=variables, + threads=kwargs.pop("threads", None), ) if type(dbt_python_config) != config_type: dbt_python_config = convert_config_type(dbt_python_config, config_type) diff --git a/sqlmesh/dbt/loader.py b/sqlmesh/dbt/loader.py index eb117a3e40..39973776a8 100644 --- a/sqlmesh/dbt/loader.py +++ b/sqlmesh/dbt/loader.py @@ -49,6 +49,7 @@ def sqlmesh_config( dbt_profile_name: t.Optional[str] = None, dbt_target_name: t.Optional[str] = None, variables: t.Optional[t.Dict[str, t.Any]] = None, + threads: t.Optional[int] = None, register_comments: t.Optional[bool] = None, **kwargs: t.Any, ) -> Config: @@ -67,6 +68,10 @@ def sqlmesh_config( if not issubclass(loader, DbtLoader): raise ConfigError("The loader must be a DbtLoader.") + if threads is not None: + # the to_sqlmesh() function on TargetConfig maps self.threads -> concurrent_tasks + profile.target.threads = threads + return Config( loader=loader, model_defaults=model_defaults, diff --git a/sqlmesh_dbt/cli.py b/sqlmesh_dbt/cli.py index 83230de3fd..ec11e7730e 100644 --- a/sqlmesh_dbt/cli.py +++ b/sqlmesh_dbt/cli.py @@ -8,11 +8,13 @@ import functools -def _get_dbt_operations(ctx: click.Context, vars: t.Optional[t.Dict[str, t.Any]]) -> DbtOperations: +def _get_dbt_operations( + ctx: click.Context, vars: t.Optional[t.Dict[str, t.Any]], threads: t.Optional[int] = None +) -> DbtOperations: if not isinstance(ctx.obj, functools.partial): raise ValueError(f"Unexpected click context object: {type(ctx.obj)}") - dbt_operations = ctx.obj(vars=vars) + dbt_operations = ctx.obj(vars=vars, threads=threads) if not isinstance(dbt_operations, DbtOperations): raise ValueError(f"Unexpected dbt operations type: {type(dbt_operations)}") @@ -128,16 +130,22 @@ def dbt( @click.option( "--empty/--no-empty", default=False, help="If specified, limit input refs and sources" ) +@click.option( + "--threads", + type=int, + help="Specify number of threads to use while executing models. Overrides settings in profiles.yml.", +) @vars_option @click.pass_context def run( ctx: click.Context, vars: t.Optional[t.Dict[str, t.Any]], + threads: t.Optional[int], env: t.Optional[str] = None, **kwargs: t.Any, ) -> None: """Compile SQL and execute against the current target database.""" - _get_dbt_operations(ctx, vars).run(environment=env, **kwargs) + _get_dbt_operations(ctx, vars, threads).run(environment=env, **kwargs) @dbt.command(name="list") diff --git a/sqlmesh_dbt/operations.py b/sqlmesh_dbt/operations.py index 6e8b452b28..cb1ac217cc 100644 --- a/sqlmesh_dbt/operations.py +++ b/sqlmesh_dbt/operations.py @@ -235,6 +235,7 @@ def create( profile: t.Optional[str] = None, target: t.Optional[str] = None, vars: t.Optional[t.Dict[str, t.Any]] = None, + threads: t.Optional[int] = None, debug: bool = False, ) -> DbtOperations: with Progress(transient=True) as progress: @@ -265,7 +266,9 @@ def create( sqlmesh_context = Context( paths=[project_dir], - config_loader_kwargs=dict(profile=profile, target=target, variables=vars), + config_loader_kwargs=dict( + profile=profile, target=target, variables=vars, threads=threads + ), load=True, # DbtSelector selects based on dbt model fqn's rather than SQLMesh model names selector=DbtSelector, diff --git a/tests/dbt/cli/test_operations.py b/tests/dbt/cli/test_operations.py index b23c87882a..139336297c 100644 --- a/tests/dbt/cli/test_operations.py +++ b/tests/dbt/cli/test_operations.py @@ -333,3 +333,33 @@ def test_run_option_full_refresh_with_selector(jaffle_shop_duckdb: Path): assert not plan.empty_backfill assert not plan.skip_backfill assert plan.models_to_backfill == set(['"jaffle_shop"."main"."stg_customers"']) + + +def test_create_sets_concurrent_tasks_based_on_threads(create_empty_project: EmptyProjectCreator): + project_dir, _ = create_empty_project(project_name="test") + + # add a postgres target because duckdb overrides to concurrent_tasks=1 regardless of what gets specified + profiles_yml_file = project_dir / "profiles.yml" + profiles_yml = yaml.load(profiles_yml_file) + profiles_yml["test"]["outputs"]["postgres"] = { + "type": "postgres", + "host": "localhost", + "port": 5432, + "user": "postgres", + "password": "postgres", + "dbname": "test", + "schema": "test", + } + profiles_yml_file.write_text(yaml.dump(profiles_yml)) + + operations = create(project_dir=project_dir, target="postgres") + + assert operations.context.concurrent_tasks == 1 # 1 is the default + + operations = create(project_dir=project_dir, threads=16, target="postgres") + + assert operations.context.concurrent_tasks == 16 + assert all( + g.connection and g.connection.concurrent_tasks == 16 + for g in operations.context.config.gateways.values() + ) diff --git a/tests/dbt/cli/test_run.py b/tests/dbt/cli/test_run.py index 755553bb57..4fdb7a0cdb 100644 --- a/tests/dbt/cli/test_run.py +++ b/tests/dbt/cli/test_run.py @@ -83,3 +83,11 @@ def test_run_with_changes_and_full_refresh( ("foo", "bar", "changed"), ("baz", "bing", "changed"), ] + + +def test_run_with_threads(jaffle_shop_duckdb: Path, invoke_cli: t.Callable[..., Result]): + result = invoke_cli(["run", "--threads", "4"]) + assert result.exit_code == 0 + assert not result.exception + + assert "Model batches executed" in result.output