From 1b7a40bae526bf0386faf3117e91c782cac64525 Mon Sep 17 00:00:00 2001 From: Sam Estep Date: Wed, 9 Jun 2021 15:56:39 -0700 Subject: [PATCH 1/6] Add --label argument to label new PRs --- ghstack/__main__.py | 5 +++++ ghstack/github_fake.py | 47 ++++++++++++++++++++++++++++++++++++++++++ ghstack/submit.py | 12 +++++++++++ test_ghstack.py | 34 +++++++++++++++++++++++++++++- 4 files changed, 97 insertions(+), 1 deletion(-) diff --git a/ghstack/__main__.py b/ghstack/__main__.py index f331d2e..0f4be40 100755 --- a/ghstack/__main__.py +++ b/ghstack/__main__.py @@ -52,6 +52,10 @@ def main() -> None: subparser.add_argument( '--draft', action='store_true', help='Create the pull request in draft mode (only if it has not already been created)') + subparser.add_argument( + '--label', action='append', default=[], + help='Add this label to any newly created pull requests ' + '(multiple --label arguments can be given)') unlink = subparsers.add_parser('unlink') unlink.add_argument('COMMITS', nargs='*') @@ -108,6 +112,7 @@ def main() -> None: force=args.force, no_skip=args.no_skip, draft=args.draft, + labels=args.label, github_url=conf.github_url, remote_name=conf.remote_name, ) diff --git a/ghstack/github_fake.py b/ghstack/github_fake.py index 3a650d1..4a06373 100644 --- a/ghstack/github_fake.py +++ b/ghstack/github_fake.py @@ -35,10 +35,29 @@ 'maintainer_can_modify': bool, }) +AddLabelsInput = TypedDict('AddLabelsInput', { + 'labels': List[str], +}) + CreatePullRequestPayload = TypedDict('CreatePullRequestPayload', { 'number': int, }) +# omitting many of these fields because we don't use them +Label = TypedDict('Label', { + # 'id': int, + # 'node_id': str, + # 'url': str, + 'name': str, + # 'description': str, + # 'color': str, + # 'default': bool, +}) + +AddLabelsPayload = List[Label] + +ListLabelsPayload = List[Label] + # The "database" for our mock instance class GitHubState: @@ -213,6 +232,7 @@ class PullRequest(Node): # state: PullRequestState title: str url: str + labels: List[Label] def repository(self, info: GraphQLResolveInfo) -> Repository: return github_state(info).repositories[self._repository] @@ -313,6 +333,7 @@ def _create_pull(self, owner: str, name: str, headRefName=input['head'], title=input['title'], body=input['body'], + labels=[], ) # TODO: compute files changed state.pull_requests[id] = pr @@ -347,12 +368,33 @@ def _set_default_branch(self, owner: str, name: str, repo = state.repository(owner, name) repo.defaultBranchRef = repo._make_ref(state, input['default_branch']) + def _add_labels(self, owner: str, name: str, number: GitHubNumber, + input: AddLabelsInput) -> AddLabelsPayload: + state = self.state + repo = state.repository(owner, name) + pr = state.pull_request(repo, number) + pr.labels += [{'name': label} for label in input['labels']] + return pr.labels + + def _list_labels(self, owner: str, name: str, number: GitHubNumber) -> ListLabelsPayload: + state = self.state + repo = state.repository(owner, name) + pr = state.pull_request(repo, number) + return pr.labels + def rest(self, method: str, path: str, **kwargs: Any) -> Any: + labels_re = r'^repos/([^/]+)/([^/]+)/issues/([^/]+)/labels$' if method == 'post': m = re.match(r'^repos/([^/]+)/([^/]+)/pulls$', path) if m: return self._create_pull(m.group(1), m.group(2), cast(CreatePullRequestInput, kwargs)) + m = re.match(labels_re, path) + if m: + owner, name, number = m.groups() + return self._add_labels( + owner, name, GitHubNumber(int(number)), + cast(AddLabelsInput, kwargs)) elif method == 'patch': m = re.match(r'^repos/([^/]+)/([^/]+)(?:/pulls/([^/]+))?$', path) if m: @@ -365,6 +407,11 @@ def rest(self, method: str, path: str, **kwargs: Any) -> Any: return self._set_default_branch( owner, name, cast(SetDefaultBranchInput, kwargs)) + elif method == 'get': + m = re.match(labels_re, path) + if m: + owner, name, number = m.groups() + return self._list_labels(owner, name, GitHubNumber(int(number))) raise NotImplementedError( "FakeGitHubEndpoint REST {} {} not implemented" .format(method.upper(), path) diff --git a/ghstack/submit.py b/ghstack/submit.py index 410e0d6..012b973 100644 --- a/ghstack/submit.py +++ b/ghstack/submit.py @@ -120,6 +120,7 @@ def main(*, force: bool = False, no_skip: bool = False, draft: bool = False, + labels: List[str], github_url: str, remote_name: str ) -> List[Optional[DiffMeta]]: @@ -185,6 +186,7 @@ def main(*, force=force, no_skip=no_skip, draft=draft, + labels=labels, stack=list(reversed(stack)), github_url=github_url, remote_name=remote_name) @@ -296,6 +298,9 @@ class Submitter(object): # Create the PR in draft mode if it is going to be created (and not updated). draft: bool + # Add these labels to newly created PRs + labels: List[str] + # Github url (normally github.com) github_url: str @@ -321,6 +326,7 @@ def __init__( force: bool, no_skip: bool, draft: bool, + labels: List[str], github_url: str, remote_name: str): self.github = github @@ -344,6 +350,7 @@ def __init__( self.force = force self.no_skip = no_skip self.draft = draft + self.labels = labels self.github_url = github_url self.remote_name = remote_name @@ -608,6 +615,11 @@ def process_new_commit(self, commit: ghstack.diff.Diff) -> None: draft=self.draft, ) number = r['number'] + if len(self.labels) > 0: + self.github.post( + f"repos/{self.repo_owner}/{self.repo_name}/issues/{number}/labels", + labels=self.labels, + ) logging.info("Opened PR #{}".format(number)) diff --git a/test_ghstack.py b/test_ghstack.py index 9d98342..0bb1bbd 100644 --- a/test_ghstack.py +++ b/test_ghstack.py @@ -97,7 +97,8 @@ def substituteRev(self, rev: str, substitute: str) -> None: def gh(self, msg: str = 'Update', update_fields: bool = False, short: bool = False, - no_skip: bool = False) -> List[Optional[ghstack.submit.DiffMeta]]: + no_skip: bool = False, + labels: List[str] = []) -> List[Optional[ghstack.submit.DiffMeta]]: return ghstack.submit.main( msg=msg, username='ezyang', @@ -109,6 +110,7 @@ def gh(self, msg: str = 'Update', repo_name='pytorch', short=short, no_skip=no_skip, + labels=labels, github_url='github.com', remote_name='origin') @@ -2005,6 +2007,36 @@ def test_default_branch_change(self) -> None: rUP1 Commit 1 rINI0 Initial commit''') + # ------------------------------------------------------------------------- # + + def test_labels(self) -> None: + # first commit + self.writeFileAndAdd('file1.txt', 'A') + self.sh.git('commit', '-m', 'Commit 1') + self.sh.test_tick() + # ghstack + self.gh() + # second commit + self.writeFileAndAdd('file2.txt', 'B') + self.sh.git('commit', '-m', 'Commit 2') + self.sh.test_tick() + # third commit + self.writeFileAndAdd('file3.txt', 'C') + self.sh.git('commit', '-m', 'Commit 3') + self.sh.test_tick() + # ghstack with labels + self.gh(labels=['foo', 'bar']) + + def get_labels(n: int) -> List[str]: + labels = self.github.get(f'repos/pytorch/pytorch/issues/{n}/labels') + return [label['name'] for label in labels] + + # was already created before second ghstack run + self.assertEqual(get_labels(500), []) + # included in the second ghstack run + self.assertEqual(get_labels(501), ['foo', 'bar']) + self.assertEqual(get_labels(502), ['foo', 'bar']) + if __name__ == '__main__': logging.basicConfig(level=logging.DEBUG, format='%(message)s') From f55db3e66848b14631c8f8a9fc9f9cb856d4184e Mon Sep 17 00:00:00 2001 From: Sam Estep Date: Fri, 11 Jun 2021 16:59:48 -0700 Subject: [PATCH 2/6] Correctly mock labels with pagination --- ghstack/github_fake.py | 77 ++++++++++++++++++++++++------------------ test_ghstack.py | 28 +++++++++++++-- 2 files changed, 70 insertions(+), 35 deletions(-) diff --git a/ghstack/github_fake.py b/ghstack/github_fake.py index 4a06373..81de9c4 100644 --- a/ghstack/github_fake.py +++ b/ghstack/github_fake.py @@ -43,21 +43,6 @@ 'number': int, }) -# omitting many of these fields because we don't use them -Label = TypedDict('Label', { - # 'id': int, - # 'node_id': str, - # 'url': str, - 'name': str, - # 'description': str, - # 'color': str, - # 'default': bool, -}) - -AddLabelsPayload = List[Label] - -ListLabelsPayload = List[Label] - # The "database" for our mock instance class GitHubState: @@ -217,6 +202,22 @@ def repository(self, info: GraphQLResolveInfo) -> Repository: return github_state(info).repositories[self._repository] +@dataclass +class Label(Node): + name: str + + +@dataclass +class PageInfo: + endCursor: Optional[str] + + +@dataclass +class LabelConnection: + nodes: Optional[List[Optional[Label]]] + pageInfo: PageInfo + + @dataclass class PullRequest(Node): baseRef: Optional[Ref] @@ -226,17 +227,37 @@ class PullRequest(Node): headRef: Optional[Ref] headRefName: str # headRepository: Optional[Repository] + _labels: List[Label] # maintainerCanModify: bool number: GitHubNumber _repository: GraphQLId # cycle breaker # state: PullRequestState title: str url: str - labels: List[Label] def repository(self, info: GraphQLResolveInfo) -> Repository: return github_state(info).repositories[self._repository] + def labels(self, info: GraphQLResolveInfo, + after: Optional[str] = None, before: Optional[str] = None, + first: Optional[int] = None, last: Optional[int] = None, + ) -> Optional[LabelConnection]: + if first is None: + # the real API also supports `last`, but we do not + raise RuntimeError( + "You must provide a `first` value" + "to properly paginate the `labels` connection." + ) + # the real API uses a more sophisticated base64-encoded syntax + # for cursors, but this serves our purposes well enough + start = int(after) if after else 0 + result: Sequence[Optional[Label]] = self._labels[start:start + first] + cursor = str(start + len(result)) if result else None + return LabelConnection( + nodes=list(result), + pageInfo=PageInfo(endCursor=cursor), + ) + @dataclass class PullRequestConnection: @@ -331,9 +352,9 @@ def _create_pull(self, owner: str, name: str, baseRefName=input['base'], headRef=headRef, headRefName=input['head'], + _labels=[], title=input['title'], body=input['body'], - labels=[], ) # TODO: compute files changed state.pull_requests[id] = pr @@ -368,28 +389,23 @@ def _set_default_branch(self, owner: str, name: str, repo = state.repository(owner, name) repo.defaultBranchRef = repo._make_ref(state, input['default_branch']) + # NB: This technically does have a payload, but we don't + # use it so I didn't bother constructing it. def _add_labels(self, owner: str, name: str, number: GitHubNumber, - input: AddLabelsInput) -> AddLabelsPayload: + input: AddLabelsInput) -> None: state = self.state repo = state.repository(owner, name) pr = state.pull_request(repo, number) - pr.labels += [{'name': label} for label in input['labels']] - return pr.labels - - def _list_labels(self, owner: str, name: str, number: GitHubNumber) -> ListLabelsPayload: - state = self.state - repo = state.repository(owner, name) - pr = state.pull_request(repo, number) - return pr.labels + for name in input['labels']: + pr._labels.append(Label(id=state.next_id(), name=name)) def rest(self, method: str, path: str, **kwargs: Any) -> Any: - labels_re = r'^repos/([^/]+)/([^/]+)/issues/([^/]+)/labels$' if method == 'post': m = re.match(r'^repos/([^/]+)/([^/]+)/pulls$', path) if m: return self._create_pull(m.group(1), m.group(2), cast(CreatePullRequestInput, kwargs)) - m = re.match(labels_re, path) + m = re.match(r'^repos/([^/]+)/([^/]+)/issues/([^/]+)/labels$', path) if m: owner, name, number = m.groups() return self._add_labels( @@ -407,11 +423,6 @@ def rest(self, method: str, path: str, **kwargs: Any) -> Any: return self._set_default_branch( owner, name, cast(SetDefaultBranchInput, kwargs)) - elif method == 'get': - m = re.match(labels_re, path) - if m: - owner, name, number = m.groups() - return self._list_labels(owner, name, GitHubNumber(int(number))) raise NotImplementedError( "FakeGitHubEndpoint REST {} {} not implemented" .format(method.upper(), path) diff --git a/test_ghstack.py b/test_ghstack.py index 0bb1bbd..d05a7a4 100644 --- a/test_ghstack.py +++ b/test_ghstack.py @@ -2028,8 +2028,32 @@ def test_labels(self) -> None: self.gh(labels=['foo', 'bar']) def get_labels(n: int) -> List[str]: - labels = self.github.get(f'repos/pytorch/pytorch/issues/{n}/labels') - return [label['name'] for label in labels] + cursor = None + labels = [] + while True: + page = self.github.graphql(""" + query ($number: Int!, $cursor: String) { + repository(owner: "pytorch", name: "pytorch") { + pullRequest(number: $number) { + labels(first: 5, after: $cursor) { + nodes { + name + } + pageInfo { + endCursor + } + } + } + } + } + """, number=n, cursor=cursor)['data']['repository']['pullRequest']['labels'] + nodes = page['nodes'] + if nodes: + labels += [label['name'] for label in nodes] + cursor = page['pageInfo']['endCursor'] + else: + break + return labels # was already created before second ghstack run self.assertEqual(get_labels(500), []) From e3641542d0aabfb6381a974b7c686c365fe39c45 Mon Sep 17 00:00:00 2001 From: Sam Estep Date: Fri, 11 Jun 2021 17:03:13 -0700 Subject: [PATCH 3/6] Turn down the pagination, for better code coverage --- test_ghstack.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test_ghstack.py b/test_ghstack.py index d05a7a4..c0256cd 100644 --- a/test_ghstack.py +++ b/test_ghstack.py @@ -2035,7 +2035,7 @@ def get_labels(n: int) -> List[str]: query ($number: Int!, $cursor: String) { repository(owner: "pytorch", name: "pytorch") { pullRequest(number: $number) { - labels(first: 5, after: $cursor) { + labels(first: 1, after: $cursor) { nodes { name } From c8f95a0f0d59d911ccdbd0368e21fb9a7947d09e Mon Sep 17 00:00:00 2001 From: Sam Estep Date: Mon, 14 Jun 2021 13:47:35 -0700 Subject: [PATCH 4/6] Apply labels to all PRs in stack --- ghstack/__main__.py | 2 +- ghstack/github_fake.py | 26 ++++++++++++++---------- ghstack/submit.py | 12 +++++------ poetry.lock | 2 +- pyproject.toml | 1 + test_ghstack.py | 46 ++++++++++++++++++++++-------------------- 6 files changed, 48 insertions(+), 41 deletions(-) diff --git a/ghstack/__main__.py b/ghstack/__main__.py index 0f4be40..f896bc2 100755 --- a/ghstack/__main__.py +++ b/ghstack/__main__.py @@ -54,7 +54,7 @@ def main() -> None: help='Create the pull request in draft mode (only if it has not already been created)') subparser.add_argument( '--label', action='append', default=[], - help='Add this label to any newly created pull requests ' + help='Add this label to all pull requests in the stack ' '(multiple --label arguments can be given)') unlink = subparsers.add_parser('unlink') diff --git a/ghstack/github_fake.py b/ghstack/github_fake.py index 81de9c4..d34b5fc 100644 --- a/ghstack/github_fake.py +++ b/ghstack/github_fake.py @@ -3,9 +3,11 @@ import os.path import re from dataclasses import dataclass # Oof! Python 3.7 only!! +from itertools import islice from typing import Any, Dict, List, NewType, Optional, Sequence, cast import graphql +from sortedcontainers import SortedKeyList # type: ignore[import] from typing_extensions import TypedDict import ghstack.github @@ -219,7 +221,7 @@ class LabelConnection: @dataclass -class PullRequest(Node): +class PullRequest(Node): # type: ignore[no-any-unimported] baseRef: Optional[Ref] baseRefName: str body: str @@ -227,7 +229,7 @@ class PullRequest(Node): headRef: Optional[Ref] headRefName: str # headRepository: Optional[Repository] - _labels: List[Label] + _labels: SortedKeyList # type: ignore[no-any-unimported] # maintainerCanModify: bool number: GitHubNumber _repository: GraphQLId # cycle breaker @@ -248,14 +250,13 @@ def labels(self, info: GraphQLResolveInfo, "You must provide a `first` value" "to properly paginate the `labels` connection." ) - # the real API uses a more sophisticated base64-encoded syntax - # for cursors, but this serves our purposes well enough - start = int(after) if after else 0 - result: Sequence[Optional[Label]] = self._labels[start:start + first] - cursor = str(start + len(result)) if result else None + nodes = list(islice( + self._labels.irange_key(after, inclusive=(False, True)), + first, + )) return LabelConnection( - nodes=list(result), - pageInfo=PageInfo(endCursor=cursor), + nodes=nodes, + pageInfo=PageInfo(endCursor=nodes[-1].name if nodes else None), ) @@ -352,7 +353,7 @@ def _create_pull(self, owner: str, name: str, baseRefName=input['base'], headRef=headRef, headRefName=input['head'], - _labels=[], + _labels=SortedKeyList(key=lambda label: label.name), title=input['title'], body=input['body'], ) @@ -396,8 +397,11 @@ def _add_labels(self, owner: str, name: str, number: GitHubNumber, state = self.state repo = state.repository(owner, name) pr = state.pull_request(repo, number) + labels = pr._labels for name in input['labels']: - pr._labels.append(Label(id=state.next_id(), name=name)) + # https://stackoverflow.com/a/3114640 + if not any(True for _ in labels.irange_key(name, name)): + labels.add(Label(id=state.next_id(), name=name)) def rest(self, method: str, path: str, **kwargs: Any) -> Any: if method == 'post': diff --git a/ghstack/submit.py b/ghstack/submit.py index 012b973..57b1c4c 100644 --- a/ghstack/submit.py +++ b/ghstack/submit.py @@ -298,7 +298,7 @@ class Submitter(object): # Create the PR in draft mode if it is going to be created (and not updated). draft: bool - # Add these labels to newly created PRs + # Add these labels to all PRs in the stack labels: List[str] # Github url (normally github.com) @@ -615,11 +615,6 @@ def process_new_commit(self, commit: ghstack.diff.Diff) -> None: draft=self.draft, ) number = r['number'] - if len(self.labels) > 0: - self.github.post( - f"repos/{self.repo_owner}/{self.repo_name}/issues/{number}/labels", - labels=self.labels, - ) logging.info("Opened PR #{}".format(number)) @@ -952,6 +947,11 @@ def push_updates(self, *, import_help: bool = True) -> None: # noqa: C901 number=s.number), body=RE_STACK.sub(self._format_stack(i), s.body), title=s.title) + if len(self.labels) > 0: + self.github.post( + f"repos/{self.repo_owner}/{self.repo_name}/issues/{s.number}/labels", + labels=self.labels, + ) else: logging.info( "# Skipping closed https://{github_url}/{owner}/{repo}/pull/{number}" diff --git a/poetry.lock b/poetry.lock index 1ce85c6..cb11580 100644 --- a/poetry.lock +++ b/poetry.lock @@ -396,7 +396,7 @@ testing = ["pytest (>=3.5,!=3.7.3)", "pytest-checkdocs (>=1.2.3)", "pytest-flake [metadata] lock-version = "1.1" python-versions = "^3.6" -content-hash = "514a97557ab89b5b5bed7a23d296e97a7d4edda529c9a0c9860c422fdc640f7f" +content-hash = "e8f8b09d48f96771b5608841269b4c3f50547407c240297ddeb027020e0b5ddd" [metadata.files] aiohttp = [ diff --git a/pyproject.toml b/pyproject.toml index 88b6317..2b4ee48 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,7 @@ hypothesis = "^6" isort = "^5" mypy = "^0.800" pytest = "^6" +sortedcontainers = "^2" [build-system] requires = ["poetry-core>=1.0.0"] diff --git a/test_ghstack.py b/test_ghstack.py index c0256cd..885076a 100644 --- a/test_ghstack.py +++ b/test_ghstack.py @@ -2010,23 +2010,6 @@ def test_default_branch_change(self) -> None: # ------------------------------------------------------------------------- # def test_labels(self) -> None: - # first commit - self.writeFileAndAdd('file1.txt', 'A') - self.sh.git('commit', '-m', 'Commit 1') - self.sh.test_tick() - # ghstack - self.gh() - # second commit - self.writeFileAndAdd('file2.txt', 'B') - self.sh.git('commit', '-m', 'Commit 2') - self.sh.test_tick() - # third commit - self.writeFileAndAdd('file3.txt', 'C') - self.sh.git('commit', '-m', 'Commit 3') - self.sh.test_tick() - # ghstack with labels - self.gh(labels=['foo', 'bar']) - def get_labels(n: int) -> List[str]: cursor = None labels = [] @@ -2035,7 +2018,7 @@ def get_labels(n: int) -> List[str]: query ($number: Int!, $cursor: String) { repository(owner: "pytorch", name: "pytorch") { pullRequest(number: $number) { - labels(first: 1, after: $cursor) { + labels(first: 2, after: $cursor) { nodes { name } @@ -2055,11 +2038,30 @@ def get_labels(n: int) -> List[str]: break return labels - # was already created before second ghstack run + self.writeFileAndAdd('file1.txt', 'A') + self.sh.git('commit', '-m', 'Commit 1') + self.sh.test_tick() + self.gh() + self.assertEqual(get_labels(500), []) - # included in the second ghstack run - self.assertEqual(get_labels(501), ['foo', 'bar']) - self.assertEqual(get_labels(502), ['foo', 'bar']) + + self.writeFileAndAdd('file2.txt', 'B') + self.sh.git('commit', '-m', 'Commit 2') + self.sh.test_tick() + self.gh(labels=['foo', 'bar']) + + # alphabetical order + self.assertEqual(get_labels(500), ['bar', 'foo']) + self.assertEqual(get_labels(501), ['bar', 'foo']) + + self.writeFileAndAdd('file3.txt', 'C') + self.sh.git('commit', '-m', 'Commit 3') + self.sh.test_tick() + self.gh(labels=['foo', 'baz']) + + self.assertEqual(get_labels(500), ['bar', 'baz', 'foo']) + self.assertEqual(get_labels(501), ['bar', 'baz', 'foo']) + self.assertEqual(get_labels(502), ['baz', 'foo']) if __name__ == '__main__': From ba6e2c3bfe344654b05a6bf6b1ca67c9eb478365 Mon Sep 17 00:00:00 2001 From: Sam Estep Date: Tue, 15 Jun 2021 09:13:13 -0700 Subject: [PATCH 5/6] Don't use a mutable argument default --- test_ghstack.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test_ghstack.py b/test_ghstack.py index 885076a..bf95304 100644 --- a/test_ghstack.py +++ b/test_ghstack.py @@ -98,7 +98,8 @@ def gh(self, msg: str = 'Update', update_fields: bool = False, short: bool = False, no_skip: bool = False, - labels: List[str] = []) -> List[Optional[ghstack.submit.DiffMeta]]: + labels: Optional[List[str]] = None + ) -> List[Optional[ghstack.submit.DiffMeta]]: return ghstack.submit.main( msg=msg, username='ezyang', @@ -110,7 +111,7 @@ def gh(self, msg: str = 'Update', repo_name='pytorch', short=short, no_skip=no_skip, - labels=labels, + labels=labels or [], github_url='github.com', remote_name='origin') From a80998f4461d98edbfdf8bf4a54b81768a2f3090 Mon Sep 17 00:00:00 2001 From: Sam Estep Date: Tue, 15 Jun 2021 09:19:23 -0700 Subject: [PATCH 6/6] Simplify test_labels helper function --- test_ghstack.py | 50 ++++++++++++++++++++----------------------------- 1 file changed, 20 insertions(+), 30 deletions(-) diff --git a/test_ghstack.py b/test_ghstack.py index bf95304..5b9f3f1 100644 --- a/test_ghstack.py +++ b/test_ghstack.py @@ -2011,40 +2011,30 @@ def test_default_branch_change(self) -> None: # ------------------------------------------------------------------------- # def test_labels(self) -> None: - def get_labels(n: int) -> List[str]: - cursor = None - labels = [] - while True: - page = self.github.graphql(""" - query ($number: Int!, $cursor: String) { - repository(owner: "pytorch", name: "pytorch") { - pullRequest(number: $number) { - labels(first: 2, after: $cursor) { - nodes { - name - } - pageInfo { - endCursor - } - } + def assert_labels(pr: int, expected: List[str]) -> None: + raw = self.github.graphql(""" + query ($pr: Int!, $first: Int!) { + repository(owner: "pytorch", name: "pytorch") { + pullRequest(number: $pr) { + labels(first: $first) { + nodes { + name } } } - """, number=n, cursor=cursor)['data']['repository']['pullRequest']['labels'] - nodes = page['nodes'] - if nodes: - labels += [label['name'] for label in nodes] - cursor = page['pageInfo']['endCursor'] - else: - break - return labels + } + } + """, pr=pr, first=len(expected) + 1) + nodes = raw['data']['repository']['pullRequest']['labels']['nodes'] + actual = [label['name'] for label in nodes] + self.assertEqual(actual, expected) self.writeFileAndAdd('file1.txt', 'A') self.sh.git('commit', '-m', 'Commit 1') self.sh.test_tick() self.gh() - self.assertEqual(get_labels(500), []) + assert_labels(500, []) self.writeFileAndAdd('file2.txt', 'B') self.sh.git('commit', '-m', 'Commit 2') @@ -2052,17 +2042,17 @@ def get_labels(n: int) -> List[str]: self.gh(labels=['foo', 'bar']) # alphabetical order - self.assertEqual(get_labels(500), ['bar', 'foo']) - self.assertEqual(get_labels(501), ['bar', 'foo']) + assert_labels(500, ['bar', 'foo']) + assert_labels(501, ['bar', 'foo']) self.writeFileAndAdd('file3.txt', 'C') self.sh.git('commit', '-m', 'Commit 3') self.sh.test_tick() self.gh(labels=['foo', 'baz']) - self.assertEqual(get_labels(500), ['bar', 'baz', 'foo']) - self.assertEqual(get_labels(501), ['bar', 'baz', 'foo']) - self.assertEqual(get_labels(502), ['baz', 'foo']) + assert_labels(500, ['bar', 'baz', 'foo']) + assert_labels(501, ['bar', 'baz', 'foo']) + assert_labels(502, ['baz', 'foo']) if __name__ == '__main__':