diff --git a/dvc/repo/run.py b/dvc/repo/run.py index 43b17f53b9..02cbf6fafd 100644 --- a/dvc/repo/run.py +++ b/dvc/repo/run.py @@ -11,7 +11,9 @@ def is_valid_name(name: str): - return not {"\\", "/", "@", ":"} & set(name) + from ..stage import INVALID_STAGENAME_CHARS + + return not INVALID_STAGENAME_CHARS & set(name) def _get_file_path(kwargs): diff --git a/dvc/stage/__init__.py b/dvc/stage/__init__.py index a1101f0715..aef2c81171 100644 --- a/dvc/stage/__init__.py +++ b/dvc/stage/__init__.py @@ -2,6 +2,7 @@ import os import pathlib import signal +import string import subprocess import threading from itertools import chain, product @@ -29,6 +30,8 @@ from .params import OutputParams logger = logging.getLogger(__name__) +# Disallow all punctuation characters except hyphen and underscore +INVALID_STAGENAME_CHARS = set(string.punctuation) - {"_", "-"} def loads_from(cls, repo, path, wdir, data): diff --git a/dvc/stage/exceptions.py b/dvc/stage/exceptions.py index 19d07aea67..a98ef154a9 100644 --- a/dvc/stage/exceptions.py +++ b/dvc/stage/exceptions.py @@ -113,7 +113,4 @@ def __init__(self, name, file): class InvalidStageName(DvcException): def __init__(self,): - super().__init__( - "Stage name cannot contain invalid characters: " - "'\\', '/', '@' and ':'." - ) + super().__init__("Stage name cannot contain punctuation characters.") diff --git a/tests/func/test_repro.py b/tests/func/test_repro.py index 10652b14d1..7c6cfcb143 100644 --- a/tests/func/test_repro.py +++ b/tests/func/test_repro.py @@ -94,14 +94,14 @@ def test(self): deps=[self.FOO], outs=["bar.txt"], cmd="echo bar > bar.txt", - name="copybarbar.txt", + name="copybarbar-txt", ) self._run( deps=["bar.txt"], outs=["baz.txt"], cmd="echo baz > baz.txt", - name="copybazbaz.txt", + name="copybazbaz-txt", ) stage_dump = { @@ -657,7 +657,7 @@ def test(self): class TestReproPhony(TestReproChangedData): def test(self): stage = self._run( - cmd="cat " + self.file1, deps=[self.file1], name="no_cmd?" + cmd="cat " + self.file1, deps=[self.file1], name="no_cmd" ) self.swap_foo_with_bar() diff --git a/tests/func/test_run_multistage.py b/tests/func/test_run_multistage.py index a84a38620d..bb8e7d489b 100644 --- a/tests/func/test_run_multistage.py +++ b/tests/func/test_run_multistage.py @@ -160,10 +160,17 @@ def test_run_dump_on_multistage(tmp_dir, dvc): } -def test_run_with_invalid_stage_name(tmp_dir, dvc, run_copy): - tmp_dir.dvc_gen("foo", "foo") +@pytest.mark.parametrize( + "char", ["@:", "#", "$", ":", "/", "\\", ".", ";", ","] +) +def test_run_with_invalid_stage_name(run_copy, char): with pytest.raises(InvalidStageName): - run_copy("foo", "bar", name="email@https://dvc.org") + run_copy("foo", "bar", name="copy_name-{}".format(char)) + + +def test_run_with_name_having_hyphen_underscore(tmp_dir, dvc, run_copy): + tmp_dir.dvc_gen("foo", "foo") + run_copy("foo", "bar", name="copy-foo_bar") def test_run_already_exists(tmp_dir, dvc, run_copy):