diff --git a/ayon_api/graphql.py b/ayon_api/graphql.py index bd6a64efe..771752364 100644 --- a/ayon_api/graphql.py +++ b/ayon_api/graphql.py @@ -1,15 +1,23 @@ +from __future__ import annotations + import copy import numbers from abc import ABC, abstractmethod -from typing import Optional, Iterable +import typing +from typing import Optional, Iterable, Any, Generator from .exceptions import GraphQlQueryFailed from .utils import SortOrder +if typing.TYPE_CHECKING: + from typing import Union + + from .server_api import ServerAPI + FIELD_VALUE = object() -def fields_to_dict(fields): +def fields_to_dict(fields: Optional[Iterable[str]]) -> dict: output = {} if not fields: return output @@ -31,7 +39,7 @@ def fields_to_dict(fields): return output -class QueryVariable(object): +class QueryVariable: """Object representing single varible used in GraphQlQuery. Variable definition is in GraphQl query header but it's value is used @@ -41,28 +49,27 @@ class QueryVariable(object): variable_name (str): Name of variable in query. """ - - def __init__(self, variable_name): + def __init__(self, variable_name: str) -> None: self._variable_name = variable_name - self._name = "${}".format(variable_name) + self._name = f"${variable_name}" @property - def name(self): + def name(self) -> str: """Name used in field filter.""" return self._name @property - def variable_name(self): + def variable_name(self) -> str: """Name of variable in query definition.""" return self._variable_name def __hash__(self): return self._name.__hash__() - def __str__(self): + def __str__(self) -> str: return self._name - def __format__(self, *args, **kwargs): + def __format__(self, *args, **kwargs) -> str: return self._name.__format__(*args, **kwargs) @@ -78,7 +85,7 @@ class GraphQlQuery: """ offset = 2 - def __init__(self, name, order=None): + def __init__(self, name: str, order: Optional[int] = None) -> None: self._name = name self._variables = {} self._children = [] @@ -86,7 +93,7 @@ def __init__(self, name, order=None): self._order = SortOrder.parse_value(order, SortOrder.ascending) @property - def indent(self): + def indent(self) -> int: """Indentation for preparation of query string. Returns: @@ -96,7 +103,7 @@ def indent(self): return 0 @property - def child_indent(self): + def child_indent(self) -> int: """Indentation for preparation of query string used by children. Returns: @@ -106,7 +113,7 @@ def child_indent(self): return self.indent @property - def need_query(self): + def need_query(self) -> bool: """Still need query from server. Needed for edges which use pagination. @@ -121,7 +128,7 @@ def need_query(self): return False @property - def has_multiple_edge_fields(self): + def has_multiple_edge_fields(self) -> bool: if self._has_multiple_edge_fields is None: edge_counter = 0 for child in self._children: @@ -132,7 +139,9 @@ def has_multiple_edge_fields(self): return self._has_multiple_edge_fields - def add_variable(self, key, value_type, value=None): + def add_variable( + self, key: str, value_type: str, value: Optional[Any] = None + ) -> QueryVariable: """Add variable to query. Args: @@ -163,7 +172,7 @@ def add_variable(self, key, value_type, value=None): } return variable - def get_variable(self, key): + def get_variable(self, key: str) -> QueryVariable: """Variable object. Args: @@ -175,7 +184,9 @@ def get_variable(self, key): """ return self._variables[key]["variable"] - def get_variable_value(self, key, default=None): + def get_variable_value( + self, key: str, default: Optional[Any] = None + ) -> Any: """Get Current value of variable. Args: @@ -191,7 +202,7 @@ def get_variable_value(self, key, default=None): return variable_item["value"] return default - def set_variable_value(self, key, value): + def set_variable_value(self, key: str, value: Any) -> None: """Set value for variable. Args: @@ -201,7 +212,7 @@ def set_variable_value(self, key, value): """ self._variables[key]["value"] = value - def get_variable_keys(self): + def get_variable_keys(self) -> set[str]: """Get all variable keys. Returns: @@ -210,13 +221,13 @@ def get_variable_keys(self): """ return set(self._variables.keys()) - def get_variables_values(self): + def get_variables_values(self) -> dict[str, Any]: """Calculate variable values used that should be used in query. Variables with value set to 'None' are skipped. Returns: - Dict[str, Any]: Variable values by their name. + dict[str, Any]: Variable values by their name. """ output = {} @@ -227,7 +238,7 @@ def get_variables_values(self): return output - def add_obj_field(self, field): + def add_obj_field(self, field: BaseGraphQlQueryField) -> None: """Add field object to children. Args: @@ -240,7 +251,7 @@ def add_obj_field(self, field): self._children.append(field) field.set_parent(self) - def add_field_with_edges(self, name): + def add_field_with_edges(self, name: str) -> GraphQlQueryEdgeField: """Add field with edges to query. Args: @@ -254,7 +265,7 @@ def add_field_with_edges(self, name): self.add_obj_field(item) return item - def add_field(self, name): + def add_field(self, name: str) -> GraphQlQueryField: """Add field to query. Args: @@ -270,7 +281,7 @@ def add_field(self, name): def get_field_by_keys( self, keys: Iterable[str] - ) -> Optional["BaseGraphQlQueryField"]: + ) -> Optional[BaseGraphQlQueryField]: keys = list(keys) if not keys: return None @@ -283,10 +294,10 @@ def get_field_by_keys( def get_field_by_path( self, path: str - ) -> Optional["BaseGraphQlQueryField"]: + ) -> Optional[BaseGraphQlQueryField]: return self.get_field_by_keys(path.split("/")) - def calculate_query(self): + def calculate_query(self) -> str: """Calculate query string which is sent to server. Returns: @@ -304,14 +315,12 @@ def calculate_query(self): if item["value"] is None: continue - variables.append( - "{}: {}".format(item["variable"], item["type"]) - ) + variables.append(f"{item['variable']}: {item['type']}") variables_str = "" if variables: - variables_str = "({})".format(",".join(variables)) - header = "query {}{}".format(self._name, variables_str) + variables_str = f"({','.join(variables)})" + header = f"query {self._name}{variables_str}" output = [] output.append(header + " {") @@ -321,16 +330,21 @@ def calculate_query(self): return "\n".join(output) - def parse_result(self, data, output, progress_data): + def parse_result( + self, + data: dict[str, Any], + output: dict[str, Any], + progress_data: dict[str, Any], + ) -> None: """Parse data from response for output. Output is stored to passed 'output' variable. That's because of paging during which objects must have access to both new and previous values. Args: - data (Dict[str, Any]): Data received using calculated query. - output (Dict[str, Any]): Where parsed data are stored. - progress_data (Dict[str, Any]): Data used for paging. + data (dict[str, Any]): Data received using calculated query. + output (dict[str, Any]): Where parsed data are stored. + progress_data (dict[str, Any]): Data used for paging. """ if not data: @@ -339,14 +353,14 @@ def parse_result(self, data, output, progress_data): for child in self._children: child.parse_result(data, output, progress_data) - def query(self, con): + def query(self, con: ServerAPI) -> dict[str, Any]: """Do a query from server. Args: con (ServerAPI): Connection to server with 'query' method. Returns: - Dict[str, Any]: Parsed output from GraphQl query. + dict[str, Any]: Parsed output from GraphQl query. """ progress_data = {} @@ -364,14 +378,16 @@ def query(self, con): return output - def continuous_query(self, con): + def continuous_query( + self, con: ServerAPI + ) -> Generator[dict[str, Any], None, None]: """Do a query from server. Args: con (ServerAPI): Connection to server with 'query' method. Returns: - Dict[str, Any]: Parsed output from GraphQl query. + dict[str, Any]: Parsed output from GraphQl query. """ progress_data = {} @@ -414,7 +430,12 @@ class BaseGraphQlQueryField(ABC): field. """ - def __init__(self, name, parent, order): + def __init__( + self, + name: str, + parent: Union[BaseGraphQlQueryField, GraphQlQuery], + order: SortOrder, + ): if isinstance(parent, GraphQlQuery): query_item = parent else: @@ -438,14 +459,16 @@ def __init__(self, name, parent, order): self._fetched_counter = 0 def __repr__(self): - return "<{} {}>".format(self.__class__.__name__, self.path) + return f"<{self.__class__.__name__} {self.path}>" def get_name(self) -> str: return self._name name = property(get_name) - def get_field_by_keys(self, keys: Iterable[str]): + def get_field_by_keys( + self, keys: Iterable[str] + ) -> Optional[BaseGraphQlQueryField]: keys = list(keys) if not keys: return self @@ -456,10 +479,10 @@ def get_field_by_keys(self, keys: Iterable[str]): return child.get_field_by_keys(keys) return None - def set_limit(self, limit: Optional[int]): + def set_limit(self, limit: Optional[int]) -> None: self._limit = limit - def set_order(self, order): + def set_order(self, order: SortOrder) -> None: order = SortOrder.parse_value(order) if order is None: raise ValueError( @@ -468,15 +491,20 @@ def set_order(self, order): ) self._order = order - def set_ascending_order(self, enabled=True): + def set_ascending_order(self, enabled: bool = True) -> None: self.set_order( SortOrder.ascending if enabled else SortOrder.descending ) - def set_descending_order(self, enabled=True): + def set_descending_order(self, enabled: bool = True) -> None: self.set_ascending_order(not enabled) - def add_variable(self, key, value_type, value=None): + def add_variable( + self, + key: str, + value_type: str, + value: Optional[Any] = None, + ) -> QueryVariable: """Add variable to query. Args: @@ -494,7 +522,7 @@ def add_variable(self, key, value_type, value=None): """ return self._parent.add_variable(key, value_type, value) - def get_variable(self, key): + def get_variable(self, key: str) -> QueryVariable: """Variable object. Args: @@ -507,7 +535,7 @@ def get_variable(self, key): return self._parent.get_variable(key) @property - def need_query(self): + def need_query(self) -> bool: """Still need query from server. Needed for edges which use pagination. Look into children values too. @@ -524,7 +552,7 @@ def need_query(self): return True return False - def _children_iter(self): + def _children_iter(self) -> Generator[BaseGraphQlQueryField, None, None]: """Iterate over all children fields of object. Returns: @@ -534,7 +562,7 @@ def _children_iter(self): for child in self._children: yield child - def sum_edge_fields(self, max_limit=None): + def sum_edge_fields(self, max_limit: Optional[int] = None) -> int: """Check how many edge fields query has. In case there are multiple edge fields or are nested the query can't @@ -559,36 +587,36 @@ def sum_edge_fields(self, max_limit=None): return counter @property - def offset(self): + def offset(self) -> int: return self._query_item.offset @property - def indent(self): + def indent(self) -> int: return self._parent.child_indent + self.offset @property @abstractmethod - def child_indent(self): + def child_indent(self) -> int: pass @property - def query_item(self): + def query_item(self) -> GraphQlQuery: return self._query_item @property @abstractmethod - def has_edges(self): + def has_edges(self) -> bool: pass @property - def child_has_edges(self): + def child_has_edges(self) -> bool: for child in self._children_iter(): if child.has_edges or child.child_has_edges: return True return False @property - def path(self): + def path(self) -> str: """Field path for debugging purposes. Returns: @@ -603,49 +631,53 @@ def path(self): self._path = path return self._path - def reset_cursor(self): + def reset_cursor(self) -> None: for child in self._children_iter(): child.reset_cursor() - def get_variable_value(self, *args, **kwargs): - return self._query_item.get_variable_value(*args, **kwargs) + def get_variable_value( + self, key: str, default: Optional[Any] = None + ) -> Any: + return self._query_item.get_variable_value(key, default) - def set_variable_value(self, *args, **kwargs): - return self._query_item.set_variable_value(*args, **kwargs) + def set_variable_value(self, key: str, value: Any) -> None: + self._query_item.set_variable_value(key, value) - def set_filter(self, key, value): + def set_filter(self, key: str, value: Any) -> None: self._filters[key] = value - def has_filter(self, key): + def has_filter(self, key: str) -> bool: return key in self._filters - def remove_filter(self, key): + def remove_filter(self, key: str) -> None: self._filters.pop(key, None) - def set_parent(self, parent): + def set_parent( + self, parent: Union[BaseGraphQlQueryField, GraphQlQuery] + ) -> None: if self._parent is parent: return self._parent = parent parent.add_obj_field(self) - def add_obj_field(self, field): + def add_obj_field(self, field: BaseGraphQlQueryField) -> None: if field in self._children: return self._children.append(field) field.set_parent(self) - def add_field_with_edges(self, name): + def add_field_with_edges(self, name: str) -> GraphQlQueryEdgeField: item = GraphQlQueryEdgeField(name, self, self._order) self.add_obj_field(item) return item - def add_field(self, name): + def add_field(self, name: str) -> GraphQlQueryField: item = GraphQlQueryField(name, self, self._order) self.add_obj_field(item) return item - def _filter_value_to_str(self, value): + def _filter_value_to_str(self, value: Any) -> Optional[str]: if isinstance(value, QueryVariable): if self.get_variable_value(value.variable_name) is None: return None @@ -655,31 +687,31 @@ def _filter_value_to_str(self, value): return str(value) if isinstance(value, str): - return '"{}"'.format(value) + return f'"{value}"' if isinstance(value, (list, set, tuple)): - return "[{}]".format( - ", ".join( - self._filter_value_to_str(item) - for item in iter(value) - ) + joined_values = ", ".join( + self._filter_value_to_str(item) + for item in iter(value) ) + return f"[{joined_values}]" + raise TypeError( "Unknown type to convert '{}'".format(str(type(value))) ) - def get_filters(self): + def get_filters(self) -> dict[str, Any]: """Receive filters for item. By default just use copy of set filters. Returns: - Dict[str, Any]: Fields filters. + dict[str, Any]: Fields filters. """ return copy.deepcopy(self._filters) - def _filters_to_string(self): + def _filters_to_string(self) -> str: filters = self.get_filters() if not filters: return "" @@ -690,23 +722,29 @@ def _filters_to_string(self): if string_value is None: continue - filter_items.append("{}: {}".format(key, string_value)) + filter_items.append(f"{key}: {string_value}") if not filter_items: return "" - return "({})".format(", ".join(filter_items)) + joined_items = ", ".join(filter_items) + return f"({joined_items})" - def _fake_children_parse(self): + def _fake_children_parse(self) -> None: """Mark children as they don't need query.""" for child in self._children_iter(): child.parse_result({}, {}, {}) @abstractmethod - def calculate_query(self): + def calculate_query(self) -> str: pass @abstractmethod - def parse_result(self, data, output, progress_data): + def parse_result( + self, + data: dict[str, Any], + output: dict[str, Any], + progress_data: dict[str, Any], + ) -> None: pass @@ -714,14 +752,19 @@ class GraphQlQueryField(BaseGraphQlQueryField): has_edges = False @property - def child_indent(self): + def child_indent(self) -> int: return self.indent - def parse_result(self, data, output, progress_data): + def parse_result( + self, + data: dict[str, Any], + output: dict[str, Any], + progress_data: dict[str, Any], + ) -> None: if not isinstance(data, dict): - raise TypeError("{} Expected 'dict' type got '{}'".format( - self._name, str(type(data)) - )) + raise TypeError( + f"{self._name} Expected 'dict' type got '{type(data)}'" + ) self._need_query = False value = data.get(self._name) @@ -763,13 +806,9 @@ def parse_result(self, data, output, progress_data): for child in self._children: child.parse_result(item, item_value, progress_data) - def calculate_query(self): + def calculate_query(self) -> str: offset = self.indent * " " - header = "{}{}{}".format( - offset, - self._name, - self._filters_to_string() - ) + header = f"{offset}{self._name}{self._filters_to_string()}" if not self._children: return header @@ -794,43 +833,48 @@ def __init__(self, *args, **kwargs): self._edge_children = [] @property - def child_indent(self): + def child_indent(self) -> int: offset = self.offset * 2 return self.indent + offset - def _children_iter(self): + def _children_iter(self) -> Generator[BaseGraphQlQueryField, None, None]: for child in super()._children_iter(): yield child for child in self._edge_children: yield child - def add_obj_field(self, field): + def add_obj_field(self, field: BaseGraphQlQueryField) -> None: if field in self._edge_children: return super().add_obj_field(field) - def add_obj_edge_field(self, field): + def add_obj_edge_field(self, field: BaseGraphQlQueryField) -> None: if field in self._edge_children or field in self._children: return self._edge_children.append(field) field.set_parent(self) - def add_edge_field(self, name): - item = GraphQlQueryField(name, self, self._order) + def add_edge_field(self, name: str) -> GraphQlQueryEdgeField: + item = GraphQlQueryEdgeField(name, self, self._order) self.add_obj_edge_field(item) return item - def reset_cursor(self): + def reset_cursor(self) -> None: # Reset cursor only for edges self._cursor = None self._need_query = True super().reset_cursor() - def parse_result(self, data, output, progress_data): + def parse_result( + self, + data: dict[str, Any], + output: dict[str, Any], + progress_data: dict[str, Any], + ) -> None: if not isinstance(data, dict): raise TypeError("{} Expected 'dict' type got '{}'".format( self._name, str(type(data)) @@ -848,13 +892,13 @@ def parse_result(self, data, output, progress_data): node_values = [] output[self._name] = node_values + nodes_by_cursor = {} handle_cursors = self.child_has_edges if handle_cursors: cursor_key = self._get_cursor_key() if cursor_key in progress_data: nodes_by_cursor = progress_data[cursor_key] else: - nodes_by_cursor = {} progress_data[cursor_key] = nodes_by_cursor page_info = value["pageInfo"] @@ -900,10 +944,10 @@ def parse_result(self, data, output, progress_data): child.reset_cursor() self._cursor = new_cursor - def _get_cursor_key(self): - return "{}/__cursor__".format(self.path) + def _get_cursor_key(self) -> str: + return f"{self.path}/__cursor__" - def get_filters(self): + def get_filters(self) -> dict[str, Any]: filters = super().get_filters() limit_key = "first" if self._order == SortOrder.descending: @@ -921,18 +965,14 @@ def get_filters(self): filters["after"] = self._cursor return filters - def calculate_query(self): + def calculate_query(self) -> str: if not self._children and not self._edge_children: raise ValueError("Missing child definitions for edges {}".format( self.path )) offset = self.indent * " " - header = "{}{}{}".format( - offset, - self._name, - self._filters_to_string() - ) + header = f"{offset}{self._name}{self._filters_to_string()}" output = [] output.append(header + " {")