diff --git a/.gitignore b/.gitignore index 699c26d..14ec495 100644 --- a/.gitignore +++ b/.gitignore @@ -7,6 +7,7 @@ *.so build distr +docs/generated tests/config/*.xml junit*.xml pyignite.egg-info diff --git a/.travis.yml b/.travis.yml index 7e726be..74909b8 100644 --- a/.travis.yml +++ b/.travis.yml @@ -43,6 +43,9 @@ jobs: - python: '3.8' arch: amd64 env: TOXENV=py38 + - python: '3.8' + arch: amd64 + env: TOXENV=codestyle - python: '3.9' arch: amd64 env: TOXENV=py39 diff --git a/examples/create_binary.py b/examples/create_binary.py index c963796..b199527 100644 --- a/examples/create_binary.py +++ b/examples/create_binary.py @@ -23,44 +23,44 @@ client.connect('127.0.0.1', 10800) student_cache = client.create_cache({ - PROP_NAME: 'SQL_PUBLIC_STUDENT', - PROP_SQL_SCHEMA: 'PUBLIC', - PROP_QUERY_ENTITIES: [ - { - 'table_name': 'Student'.upper(), - 'key_field_name': 'SID', - 'key_type_name': 'java.lang.Integer', - 'field_name_aliases': [], - 'query_fields': [ - { - 'name': 'SID', - 'type_name': 'java.lang.Integer', - 'is_key_field': True, - 'is_notnull_constraint_field': True, - }, - { - 'name': 'NAME', - 'type_name': 'java.lang.String', - }, - { - 'name': 'LOGIN', - 'type_name': 'java.lang.String', - }, - { - 'name': 'AGE', - 'type_name': 'java.lang.Integer', - }, - { - 'name': 'GPA', - 'type_name': 'java.math.Double', - }, - ], - 'query_indexes': [], - 'value_type_name': 'SQL_PUBLIC_STUDENT_TYPE', - 'value_field_name': None, - }, - ], - }) + PROP_NAME: 'SQL_PUBLIC_STUDENT', + PROP_SQL_SCHEMA: 'PUBLIC', + PROP_QUERY_ENTITIES: [ + { + 'table_name': 'Student'.upper(), + 'key_field_name': 'SID', + 'key_type_name': 'java.lang.Integer', + 'field_name_aliases': [], + 'query_fields': [ + { + 'name': 'SID', + 'type_name': 'java.lang.Integer', + 'is_key_field': True, + 'is_notnull_constraint_field': True, + }, + { + 'name': 'NAME', + 'type_name': 'java.lang.String', + }, + { + 'name': 'LOGIN', + 'type_name': 'java.lang.String', + }, + { + 'name': 'AGE', + 'type_name': 'java.lang.Integer', + }, + { + 'name': 'GPA', + 'type_name': 'java.math.Double', + }, + ], + 'query_indexes': [], + 'value_type_name': 'SQL_PUBLIC_STUDENT_TYPE', + 'value_field_name': None, + }, + ], +}) class Student( diff --git a/examples/sql.py b/examples/sql.py index 8f0ee7c..0e8c729 100644 --- a/examples/sql.py +++ b/examples/sql.py @@ -280,7 +280,7 @@ field_data = list(*result) print('City info:') -for field_name, field_value in zip(field_names*len(field_data), field_data): +for field_name, field_value in zip(field_names * len(field_data), field_data): print('{}: {}'.format(field_name, field_value)) # City info: # ID: 3802 diff --git a/pyignite/__init__.py b/pyignite/__init__.py index 0ac346f..c26c59a 100644 --- a/pyignite/__init__.py +++ b/pyignite/__init__.py @@ -14,4 +14,7 @@ # limitations under the License. from pyignite.client import Client +from pyignite.aio_client import AioClient from pyignite.binary import GenericObjectMeta + +__version__ = '0.4.0-dev' diff --git a/pyignite/aio_cache.py b/pyignite/aio_cache.py new file mode 100644 index 0000000..b92a14c --- /dev/null +++ b/pyignite/aio_cache.py @@ -0,0 +1,600 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import asyncio +from typing import Any, Dict, Iterable, Optional, Union + +from .constants import AFFINITY_RETRIES, AFFINITY_DELAY +from .connection import AioConnection +from .datatypes import prop_codes +from .datatypes.base import IgniteDataType +from .datatypes.internal import AnyDataObject +from .exceptions import CacheCreationError, CacheError, ParameterError, connection_errors +from .utils import cache_id, status_to_exception +from .api.cache_config import ( + cache_create_async, cache_get_or_create_async, cache_destroy_async, cache_get_configuration_async, + cache_create_with_config_async, cache_get_or_create_with_config_async +) +from .api.key_value import ( + cache_get_async, cache_contains_key_async, cache_clear_key_async, cache_clear_keys_async, cache_clear_async, + cache_replace_async, cache_put_all_async, cache_get_all_async, cache_put_async, cache_contains_keys_async, + cache_get_and_put_async, cache_get_and_put_if_absent_async, cache_put_if_absent_async, cache_get_and_remove_async, + cache_get_and_replace_async, cache_remove_key_async, cache_remove_keys_async, cache_remove_all_async, + cache_remove_if_equals_async, cache_replace_if_equals_async, cache_get_size_async, +) +from .cursors import AioScanCursor +from .api.affinity import cache_get_node_partitions_async +from .cache import __parse_settings, BaseCacheMixin + + +async def get_cache(client: 'AioClient', settings: Union[str, dict]) -> 'AioCache': + name, settings = __parse_settings(settings) + if settings: + raise ParameterError('Only cache name allowed as a parameter') + + return AioCache(client, name) + + +async def create_cache(client: 'AioClient', settings: Union[str, dict]) -> 'AioCache': + name, settings = __parse_settings(settings) + + conn = await client.random_node() + if settings: + result = await cache_create_with_config_async(conn, settings) + else: + result = await cache_create_async(conn, name) + + if result.status != 0: + raise CacheCreationError(result.message) + + return AioCache(client, name) + + +async def get_or_create_cache(client: 'AioClient', settings: Union[str, dict]) -> 'AioCache': + name, settings = __parse_settings(settings) + + conn = await client.random_node() + if settings: + result = await cache_get_or_create_with_config_async(conn, settings) + else: + result = await cache_get_or_create_async(conn, name) + + if result.status != 0: + raise CacheCreationError(result.message) + + return AioCache(client, name) + + +class AioCache(BaseCacheMixin): + """ + Ignite cache abstraction. Users should never use this class directly, + but construct its instances with + :py:meth:`~pyignite.client.Client.create_cache`, + :py:meth:`~pyignite.client.Client.get_or_create_cache` or + :py:meth:`~pyignite.client.Client.get_cache` methods instead. See + :ref:`this example ` on how to do it. + """ + def __init__(self, client: 'AioClient', name: str): + """ + Initialize async cache object. For internal use. + + :param client: Async Ignite client, + :param name: Cache name. + """ + self._client = client + self._name = name + self._cache_id = cache_id(self._name) + self._settings = None + self._affinity_query_mux = asyncio.Lock() + self.affinity = {'version': (0, 0)} + + async def settings(self) -> Optional[dict]: + """ + Lazy Cache settings. See the :ref:`example ` + of reading this property. + + All cache properties are documented here: :ref:`cache_props`. + + :return: dict of cache properties and their values. + """ + if self._settings is None: + conn = await self.get_best_node() + config_result = await cache_get_configuration_async(conn, self._cache_id) + + if config_result.status == 0: + self._settings = config_result.value + else: + raise CacheError(config_result.message) + + return self._settings + + async def name(self) -> str: + """ + Lazy cache name. + + :return: cache name string. + """ + if self._name is None: + settings = await self.settings() + self._name = settings[prop_codes.PROP_NAME] + + return self._name + + @property + def client(self) -> 'AioClient': + """ + Ignite :class:`~pyignite.aio_client.AioClient` object. + + :return: Async client object, through which the cache is accessed. + """ + return self._client + + @property + def cache_id(self) -> int: + """ + Cache ID. + + :return: integer value of the cache ID. + """ + return self._cache_id + + @status_to_exception(CacheError) + async def destroy(self): + """ + Destroys cache with a given name. + """ + conn = await self.get_best_node() + return await cache_destroy_async(conn, self._cache_id) + + @status_to_exception(CacheError) + async def _get_affinity(self, conn: 'AioConnection') -> Dict: + """ + Queries server for affinity mappings. Retries in case + of an intermittent error (most probably “Getting affinity for topology + version earlier than affinity is calculated”). + + :param conn: connection to Igneite server, + :return: OP_CACHE_PARTITIONS operation result value. + """ + for _ in range(AFFINITY_RETRIES or 1): + result = await cache_get_node_partitions_async(conn, self._cache_id) + if result.status == 0 and result.value['partition_mapping']: + break + await asyncio.sleep(AFFINITY_DELAY) + + return result + + async def get_best_node(self, key: Any = None, key_hint: 'IgniteDataType' = None) -> 'AioConnection': + """ + Returns the node from the list of the nodes, opened by client, that + most probably contains the needed key-value pair. See IEP-23. + + This method is not a part of the public API. Unless you wish to + extend the `pyignite` capabilities (with additional testing, logging, + examining connections, et c.) you probably should not use it. + + :param key: (optional) pythonic key, + :param key_hint: (optional) Ignite data type, for which the given key + should be converted, + :return: Ignite connection object. + """ + conn = await self._client.random_node() + + if self.client.partition_aware and key is not None: + if self.__should_update_mapping(): + async with self._affinity_query_mux: + while self.__should_update_mapping(): + try: + full_affinity = await self._get_affinity(conn) + self._update_affinity(full_affinity) + + asyncio.ensure_future( + asyncio.gather( + *[conn.reconnect() for conn in self.client._nodes if not conn.alive], + return_exceptions=True + ) + ) + + break + except connection_errors: + # retry if connection failed + conn = await self._client.random_node() + pass + except CacheError: + # server did not create mapping in time + return conn + + parts = self.affinity.get('number_of_partitions') + + if not parts: + return conn + + key, key_hint = self._get_affinity_key(key, key_hint) + + hashcode = await key_hint.hashcode_async(key, self._client) + + best_node = self._get_node_by_hashcode(hashcode, parts) + if best_node: + return best_node + + return conn + + def __should_update_mapping(self): + return self.affinity['version'] < self._client.affinity_version + + @status_to_exception(CacheError) + async def get(self, key, key_hint: object = None) -> Any: + """ + Retrieves a value from cache by key. + + :param key: key for the cache entry. Can be of any supported type, + :param key_hint: (optional) Ignite data type, for which the given key + should be converted, + :return: value retrieved. + """ + if key_hint is None: + key_hint = AnyDataObject.map_python_type(key) + + conn = await self.get_best_node(key, key_hint) + result = await cache_get_async(conn, self._cache_id, key, key_hint=key_hint) + result.value = await self.client.unwrap_binary(result.value) + return result + + @status_to_exception(CacheError) + async def put(self, key, value, key_hint: object = None, value_hint: object = None): + """ + Puts a value with a given key to cache (overwriting existing value + if any). + + :param key: key for the cache entry. Can be of any supported type, + :param value: value for the key, + :param key_hint: (optional) Ignite data type, for which the given key + should be converted, + :param value_hint: (optional) Ignite data type, for which the given + value should be converted. + """ + if key_hint is None: + key_hint = AnyDataObject.map_python_type(key) + + conn = await self.get_best_node(key, key_hint) + return await cache_put_async(conn, self._cache_id, key, value, key_hint=key_hint, value_hint=value_hint) + + @status_to_exception(CacheError) + async def get_all(self, keys: list) -> list: + """ + Retrieves multiple key-value pairs from cache. + + :param keys: list of keys or tuples of (key, key_hint), + :return: a dict of key-value pairs. + """ + conn = await self.get_best_node() + result = await cache_get_all_async(conn, self._cache_id, keys) + if result.value: + keys = list(result.value.keys()) + values = await asyncio.gather(*[self.client.unwrap_binary(value) for value in result.value.values()]) + + for i, key in enumerate(keys): + result.value[key] = values[i] + return result + + @status_to_exception(CacheError) + async def put_all(self, pairs: dict): + """ + Puts multiple key-value pairs to cache (overwriting existing + associations if any). + + :param pairs: dictionary type parameters, contains key-value pairs + to save. Each key or value can be an item of representable + Python type or a tuple of (item, hint), + """ + conn = await self.get_best_node() + return await cache_put_all_async(conn, self._cache_id, pairs) + + @status_to_exception(CacheError) + async def replace(self, key, value, key_hint: object = None, value_hint: object = None): + """ + Puts a value with a given key to cache only if the key already exist. + + :param key: key for the cache entry. Can be of any supported type, + :param value: value for the key, + :param key_hint: (optional) Ignite data type, for which the given key + should be converted, + :param value_hint: (optional) Ignite data type, for which the given + value should be converted. + """ + if key_hint is None: + key_hint = AnyDataObject.map_python_type(key) + + conn = await self.get_best_node(key, key_hint) + result = await cache_replace_async(conn, self._cache_id, key, value, key_hint=key_hint, value_hint=value_hint) + result.value = await self.client.unwrap_binary(result.value) + return result + + @status_to_exception(CacheError) + async def clear(self, keys: Optional[list] = None): + """ + Clears the cache without notifying listeners or cache writers. + + :param keys: (optional) list of cache keys or (key, key type + hint) tuples to clear (default: clear all). + """ + conn = await self.get_best_node() + if keys: + return await cache_clear_keys_async(conn, self._cache_id, keys) + else: + return await cache_clear_async(conn, self._cache_id) + + @status_to_exception(CacheError) + async def clear_key(self, key, key_hint: object = None): + """ + Clears the cache key without notifying listeners or cache writers. + + :param key: key for the cache entry, + :param key_hint: (optional) Ignite data type, for which the given key + should be converted, + """ + if key_hint is None: + key_hint = AnyDataObject.map_python_type(key) + + conn = await self.get_best_node(key, key_hint) + return await cache_clear_key_async(conn, self._cache_id, key, key_hint=key_hint) + + @status_to_exception(CacheError) + async def clear_keys(self, keys: Iterable): + """ + Clears the cache key without notifying listeners or cache writers. + + :param keys: a list of keys or (key, type hint) tuples + """ + conn = await self.get_best_node() + return await cache_clear_keys_async(conn, self._cache_id, keys) + + @status_to_exception(CacheError) + async def contains_key(self, key, key_hint=None) -> bool: + """ + Returns a value indicating whether given key is present in cache. + + :param key: key for the cache entry. Can be of any supported type, + :param key_hint: (optional) Ignite data type, for which the given key + should be converted, + :return: boolean `True` when key is present, `False` otherwise. + """ + if key_hint is None: + key_hint = AnyDataObject.map_python_type(key) + + conn = await self.get_best_node(key, key_hint) + return await cache_contains_key_async(conn, self._cache_id, key, key_hint=key_hint) + + @status_to_exception(CacheError) + async def contains_keys(self, keys: Iterable) -> bool: + """ + Returns a value indicating whether all given keys are present in cache. + + :param keys: a list of keys or (key, type hint) tuples, + :return: boolean `True` when all keys are present, `False` otherwise. + """ + conn = await self.get_best_node() + return await cache_contains_keys_async(conn, self._cache_id, keys) + + @status_to_exception(CacheError) + async def get_and_put(self, key, value, key_hint=None, value_hint=None) -> Any: + """ + Puts a value with a given key to cache, and returns the previous value + for that key, or null value if there was not such key. + + :param key: key for the cache entry. Can be of any supported type, + :param value: value for the key, + :param key_hint: (optional) Ignite data type, for which the given key + should be converted, + :param value_hint: (optional) Ignite data type, for which the given + value should be converted. + :return: old value or None. + """ + if key_hint is None: + key_hint = AnyDataObject.map_python_type(key) + + conn = await self.get_best_node(key, key_hint) + result = await cache_get_and_put_async(conn, self._cache_id, key, value, key_hint, value_hint) + + result.value = await self.client.unwrap_binary(result.value) + return result + + @status_to_exception(CacheError) + async def get_and_put_if_absent(self, key, value, key_hint=None, value_hint=None): + """ + Puts a value with a given key to cache only if the key does not + already exist. + + :param key: key for the cache entry. Can be of any supported type, + :param value: value for the key, + :param key_hint: (optional) Ignite data type, for which the given key + should be converted, + :param value_hint: (optional) Ignite data type, for which the given + value should be converted, + :return: old value or None. + """ + if key_hint is None: + key_hint = AnyDataObject.map_python_type(key) + + conn = await self.get_best_node(key, key_hint) + result = await cache_get_and_put_if_absent_async(conn, self._cache_id, key, value, key_hint, value_hint) + result.value = await self.client.unwrap_binary(result.value) + return result + + @status_to_exception(CacheError) + async def put_if_absent(self, key, value, key_hint=None, value_hint=None): + """ + Puts a value with a given key to cache only if the key does not + already exist. + + :param key: key for the cache entry. Can be of any supported type, + :param value: value for the key, + :param key_hint: (optional) Ignite data type, for which the given key + should be converted, + :param value_hint: (optional) Ignite data type, for which the given + value should be converted. + """ + if key_hint is None: + key_hint = AnyDataObject.map_python_type(key) + + conn = await self.get_best_node(key, key_hint) + return await cache_put_if_absent_async(conn, self._cache_id, key, value, key_hint, value_hint) + + @status_to_exception(CacheError) + async def get_and_remove(self, key, key_hint=None) -> Any: + """ + Removes the cache entry with specified key, returning the value. + + :param key: key for the cache entry. Can be of any supported type, + :param key_hint: (optional) Ignite data type, for which the given key + should be converted, + :return: old value or None. + """ + if key_hint is None: + key_hint = AnyDataObject.map_python_type(key) + + conn = await self.get_best_node(key, key_hint) + result = await cache_get_and_remove_async(conn, self._cache_id, key, key_hint) + result.value = await self.client.unwrap_binary(result.value) + return result + + @status_to_exception(CacheError) + async def get_and_replace(self, key, value, key_hint=None, value_hint=None) -> Any: + """ + Puts a value with a given key to cache, returning previous value + for that key, if and only if there is a value currently mapped + for that key. + + :param key: key for the cache entry. Can be of any supported type, + :param value: value for the key, + :param key_hint: (optional) Ignite data type, for which the given key + should be converted, + :param value_hint: (optional) Ignite data type, for which the given + value should be converted. + :return: old value or None. + """ + if key_hint is None: + key_hint = AnyDataObject.map_python_type(key) + + conn = await self.get_best_node(key, key_hint) + result = await cache_get_and_replace_async(conn, self._cache_id, key, value, key_hint, value_hint) + result.value = await self.client.unwrap_binary(result.value) + return result + + @status_to_exception(CacheError) + async def remove_key(self, key, key_hint=None): + """ + Clears the cache key without notifying listeners or cache writers. + + :param key: key for the cache entry, + :param key_hint: (optional) Ignite data type, for which the given key + should be converted, + """ + if key_hint is None: + key_hint = AnyDataObject.map_python_type(key) + + conn = await self.get_best_node(key, key_hint) + return await cache_remove_key_async(conn, self._cache_id, key, key_hint) + + @status_to_exception(CacheError) + async def remove_keys(self, keys: list): + """ + Removes cache entries by given list of keys, notifying listeners + and cache writers. + + :param keys: list of keys or tuples of (key, key_hint) to remove. + """ + conn = await self.get_best_node() + return await cache_remove_keys_async(conn, self._cache_id, keys) + + @status_to_exception(CacheError) + async def remove_all(self): + """ + Removes all cache entries, notifying listeners and cache writers. + """ + conn = await self.get_best_node() + return await cache_remove_all_async(conn, self._cache_id) + + @status_to_exception(CacheError) + async def remove_if_equals(self, key, sample, key_hint=None, sample_hint=None): + """ + Removes an entry with a given key if provided value is equal to + actual value, notifying listeners and cache writers. + + :param key: key for the cache entry, + :param sample: a sample to compare the stored value with, + :param key_hint: (optional) Ignite data type, for which the given key + should be converted, + :param sample_hint: (optional) Ignite data type, for whic + the given sample should be converted. + """ + if key_hint is None: + key_hint = AnyDataObject.map_python_type(key) + + conn = await self.get_best_node(key, key_hint) + return await cache_remove_if_equals_async(conn, self._cache_id, key, sample, key_hint, sample_hint) + + @status_to_exception(CacheError) + async def replace_if_equals(self, key, sample, value, key_hint=None, sample_hint=None, value_hint=None) -> Any: + """ + Puts a value with a given key to cache only if the key already exists + and value equals provided sample. + + :param key: key for the cache entry, + :param sample: a sample to compare the stored value with, + :param value: new value for the given key, + :param key_hint: (optional) Ignite data type, for which the given key + should be converted, + :param sample_hint: (optional) Ignite data type, for whic + the given sample should be converted + :param value_hint: (optional) Ignite data type, for which the given + value should be converted, + :return: boolean `True` when key is present, `False` otherwise. + """ + if key_hint is None: + key_hint = AnyDataObject.map_python_type(key) + + conn = await self.get_best_node(key, key_hint) + result = await cache_replace_if_equals_async(conn, self._cache_id, key, sample, value, key_hint, sample_hint, + value_hint) + result.value = await self.client.unwrap_binary(result.value) + return result + + @status_to_exception(CacheError) + async def get_size(self, peek_modes=0): + """ + Gets the number of entries in cache. + + :param peek_modes: (optional) limit count to near cache partition + (PeekModes.NEAR), primary cache (PeekModes.PRIMARY), or backup cache + (PeekModes.BACKUP). Defaults to all cache partitions (PeekModes.ALL), + :return: integer number of cache entries. + """ + conn = await self.get_best_node() + return await cache_get_size_async(conn, self._cache_id, peek_modes) + + def scan(self, page_size: int = 1, partitions: int = -1, local: bool = False): + """ + Returns all key-value pairs from the cache, similar to `get_all`, but + with internal pagination, which is slower, but safer. + + :param page_size: (optional) page size. Default size is 1 (slowest + and safest), + :param partitions: (optional) number of partitions to query + (negative to query entire cache), + :param local: (optional) pass True if this query should be executed + on local node only. Defaults to False, + :return: async scan query cursor + """ + return AioScanCursor(self.client, self._cache_id, page_size, partitions, local) diff --git a/pyignite/aio_client.py b/pyignite/aio_client.py new file mode 100644 index 0000000..d882969 --- /dev/null +++ b/pyignite/aio_client.py @@ -0,0 +1,358 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import asyncio +import random +from itertools import chain +from typing import Iterable, Type, Union, Any + +from .api.binary import get_binary_type_async, put_binary_type_async +from .api.cache_config import cache_get_names_async +from .client import BaseClient +from .cursors import AioSqlFieldsCursor +from .aio_cache import AioCache, get_cache, create_cache, get_or_create_cache +from .connection import AioConnection +from .constants import IGNITE_DEFAULT_HOST, IGNITE_DEFAULT_PORT +from .datatypes import BinaryObject +from .exceptions import BinaryTypeError, CacheError, ReconnectError, connection_errors +from .stream import AioBinaryStream, READ_BACKWARD +from .utils import cache_id, entity_id, status_to_exception, is_iterable, is_wrapped + + +__all__ = ['AioClient'] + + +class AioClient(BaseClient): + """ + Asynchronous Client implementation. + """ + + def __init__(self, compact_footer: bool = None, partition_aware: bool = False, **kwargs): + """ + Initialize client. + + :param compact_footer: (optional) use compact (True, recommended) or + full (False) schema approach when serializing Complex objects. + Default is to use the same approach the server is using (None). + Apache Ignite binary protocol documentation on this topic: + https://apacheignite.readme.io/docs/binary-client-protocol-data-format#section-schema + :param partition_aware: (optional) try to calculate the exact data + placement from the key before to issue the key operation to the + server node: + https://cwiki.apache.org/confluence/display/IGNITE/IEP-23%3A+Best+Effort+Affinity+for+thin+clients + The feature is in experimental status, so the parameter is `False` + by default. This will be changed later. + """ + super().__init__(compact_footer, partition_aware, **kwargs) + self._registry_mux = asyncio.Lock() + + async def connect(self, *args): + """ + Connect to Ignite cluster node(s). + + :param args: (optional) host(s) and port(s) to connect to. + """ + nodes = self._process_connect_args(*args) + + for i, node in enumerate(nodes): + host, port = node + conn = AioConnection(self, host, port, **self._connection_args) + + if not self.partition_aware: + try: + if self.protocol_version is None: + # open connection before adding to the pool + await conn.connect() + + # do not try to open more nodes + self._current_node = i + + except connection_errors: + conn.failed = True + + self._nodes.append(conn) + + if self.partition_aware: + connect_results = await asyncio.gather( + *[conn.connect() for conn in self._nodes], + return_exceptions=True + ) + + reconnect_coro = [] + for i, res in enumerate(connect_results): + if isinstance(res, Exception): + if isinstance(res, connection_errors): + reconnect_coro.append(self._nodes[i].reconnect()) + else: + raise res + + await asyncio.gather(*reconnect_coro, return_exceptions=True) + + if self.protocol_version is None: + raise ReconnectError('Can not connect.') + + async def close(self): + await asyncio.gather(*[conn.close() for conn in self._nodes], return_exceptions=True) + self._nodes.clear() + + async def random_node(self) -> AioConnection: + """ + Returns random usable node. + + This method is not a part of the public API. Unless you wish to + extend the `pyignite` capabilities (with additional testing, logging, + examining connections, et c.) you probably should not use it. + """ + if self.partition_aware: + # if partition awareness is used just pick a random connected node + return await self._get_random_node() + else: + # if partition awareness is not used then just return the current + # node if it's alive or the next usable node if connection with the + # current is broken + node = self._nodes[self._current_node] + if node.alive: + return node + + # close current (supposedly failed) node + await self._nodes[self._current_node].close() + + # advance the node index + self._current_node += 1 + if self._current_node >= len(self._nodes): + self._current_node = 0 + + # prepare the list of node indexes to try to connect to + for i in chain(range(self._current_node, len(self._nodes)), range(self._current_node)): + node = self._nodes[i] + try: + await node.connect() + except connection_errors: + pass + else: + return node + + # no nodes left + raise ReconnectError('Can not reconnect: out of nodes.') + + async def _get_random_node(self, reconnect=True): + alive_nodes = [n for n in self._nodes if n.alive] + if alive_nodes: + return random.choice(alive_nodes) + elif reconnect: + await asyncio.gather(*[n.reconnect() for n in self._nodes], return_exceptions=True) + return await self._get_random_node(reconnect=False) + else: + # cannot choose from an empty sequence + raise ReconnectError('Can not reconnect: out of nodes.') from None + + @status_to_exception(BinaryTypeError) + async def get_binary_type(self, binary_type: Union[str, int]) -> dict: + """ + Gets the binary type information from the Ignite server. This is quite + a low-level implementation of Ignite thin client protocol's + `OP_GET_BINARY_TYPE` operation. You would probably want to use + :py:meth:`~pyignite.client.Client.query_binary_type` instead. + + :param binary_type: binary type name or ID, + :return: binary type description − a dict with the following fields: + + - `type_exists`: True if the type is registered, False otherwise. In + the latter case all the following fields are omitted, + - `type_id`: Complex object type ID, + - `type_name`: Complex object type name, + - `affinity_key_field`: string value or None, + - `is_enum`: False in case of Complex object registration, + - `schemas`: a list, containing the Complex object schemas in format: + OrderedDict[field name: field type hint]. A schema can be empty. + """ + conn = await self.random_node() + result = await get_binary_type_async(conn, binary_type) + return self._process_get_binary_type_result(result) + + @status_to_exception(BinaryTypeError) + async def put_binary_type(self, type_name: str, affinity_key_field: str = None, is_enum=False, schema: dict = None): + """ + Registers binary type information in cluster. Do not update binary + registry. This is a literal implementation of Ignite thin client + protocol's `OP_PUT_BINARY_TYPE` operation. You would probably want + to use :py:meth:`~pyignite.client.Client.register_binary_type` instead. + + :param type_name: name of the data type being registered, + :param affinity_key_field: (optional) name of the affinity key field, + :param is_enum: (optional) register enum if True, binary object + otherwise. Defaults to False, + :param schema: (optional) when register enum, pass a dict + of enumerated parameter names as keys and an integers as values. + When register binary type, pass a dict of field names: field types. + Binary type with no fields is OK. + """ + conn = await self.random_node() + return await put_binary_type_async(conn, type_name, affinity_key_field, is_enum, schema) + + async def register_binary_type(self, data_class: Type, affinity_key_field: str = None): + """ + Register the given class as a representation of a certain Complex + object type. Discards autogenerated or previously registered class. + + :param data_class: Complex object class, + :param affinity_key_field: (optional) affinity parameter. + """ + if not await self.query_binary_type(data_class.type_id, data_class.schema_id): + await self.put_binary_type(data_class.type_name, affinity_key_field, schema=data_class.schema) + + self._registry[data_class.type_id][data_class.schema_id] = data_class + + async def query_binary_type(self, binary_type: Union[int, str], schema: Union[int, dict] = None): + """ + Queries the registry of Complex object classes. + + :param binary_type: Complex object type name or ID, + :param schema: (optional) Complex object schema or schema ID, + :return: found dataclass or None, if `schema` parameter is provided, + a dict of {schema ID: dataclass} format otherwise. + """ + type_id = entity_id(binary_type) + + result = self._get_from_registry(type_id, schema) + + if not result: + async with self._registry_mux: + result = self._get_from_registry(type_id, schema) + + if not result: + type_info = await self.get_binary_type(type_id) + self._sync_binary_registry(type_id, type_info) + return self._get_from_registry(type_id, schema) + + return result + + async def unwrap_binary(self, value: Any) -> Any: + """ + Detects and recursively unwraps Binary Object. + + :param value: anything that could be a Binary Object, + :return: the result of the Binary Object unwrapping with all other data + left intact. + """ + if is_wrapped(value): + blob, offset = value + with AioBinaryStream(self, blob) as stream: + data_class = await BinaryObject.parse_async(stream) + return await BinaryObject.to_python_async(stream.read_ctype(data_class, direction=READ_BACKWARD), self) + return value + + async def create_cache(self, settings: Union[str, dict]) -> 'AioCache': + """ + Creates Ignite cache by name. Raises `CacheError` if such a cache is + already exists. + + :param settings: cache name or dict of cache properties' codes + and values. All cache properties are documented here: + :ref:`cache_props`. See also the + :ref:`cache creation example `, + :return: :class:`~pyignite.cache.Cache` object. + """ + return await create_cache(self, settings) + + async def get_or_create_cache(self, settings: Union[str, dict]) -> 'AioCache': + """ + Creates Ignite cache, if not exist. + + :param settings: cache name or dict of cache properties' codes + and values. All cache properties are documented here: + :ref:`cache_props`. See also the + :ref:`cache creation example `, + :return: :class:`~pyignite.cache.Cache` object. + """ + return await get_or_create_cache(self, settings) + + async def get_cache(self, settings: Union[str, dict]) -> 'AioCache': + """ + Creates Cache object with a given cache name without checking it up + on server. If such a cache does not exist, some kind of exception + (most probably `CacheError`) may be raised later. + + :param settings: cache name or cache properties (but only `PROP_NAME` + property is allowed), + :return: :class:`~pyignite.cache.Cache` object. + """ + return await get_cache(self, settings) + + @status_to_exception(CacheError) + async def get_cache_names(self) -> list: + """ + Gets existing cache names. + + :return: list of cache names. + """ + conn = await self.random_node() + return await cache_get_names_async(conn) + + def sql( + self, query_str: str, page_size: int = 1024, + query_args: Iterable = None, schema: str = 'PUBLIC', + statement_type: int = 0, distributed_joins: bool = False, + local: bool = False, replicated_only: bool = False, + enforce_join_order: bool = False, collocated: bool = False, + lazy: bool = False, include_field_names: bool = False, + max_rows: int = -1, timeout: int = 0, + cache: Union[int, str, 'AioCache'] = None + ): + """ + Runs an SQL query and returns its result. + + :param query_str: SQL query string, + :param page_size: (optional) cursor page size. Default is 1024, which + means that client makes one server call per 1024 rows, + :param query_args: (optional) query arguments. List of values or + (value, type hint) tuples, + :param schema: (optional) schema for the query. Defaults to `PUBLIC`, + :param statement_type: (optional) statement type. Can be: + + * StatementType.ALL − any type (default), + * StatementType.SELECT − select, + * StatementType.UPDATE − update. + + :param distributed_joins: (optional) distributed joins. Defaults + to False, + :param local: (optional) pass True if this query should be executed + on local node only. Defaults to False, + :param replicated_only: (optional) whether query contains only + replicated tables or not. Defaults to False, + :param enforce_join_order: (optional) enforce join order. Defaults + to False, + :param collocated: (optional) whether your data is co-located or not. + Defaults to False, + :param lazy: (optional) lazy query execution. Defaults to False, + :param include_field_names: (optional) include field names in result. + Defaults to False, + :param max_rows: (optional) query-wide maximum of rows. Defaults to -1 + (all rows), + :param timeout: (optional) non-negative timeout value in ms. + Zero disables timeout (default), + :param cache (optional) Name or ID of the cache to use to infer schema. + If set, 'schema' argument is ignored, + :return: generator with result rows as a lists. If + `include_field_names` was set, the first row will hold field names. + """ + + c_id = cache.cache_id if isinstance(cache, AioCache) else cache_id(cache) + + if c_id != 0: + schema = None + + return AioSqlFieldsCursor(self, c_id, query_str, page_size, query_args, schema, statement_type, + distributed_joins, local, replicated_only, enforce_join_order, collocated, + lazy, include_field_names, max_rows, timeout) diff --git a/pyignite/api/__init__.py b/pyignite/api/__init__.py index 7dbef0a..7deed8c 100644 --- a/pyignite/api/__init__.py +++ b/pyignite/api/__init__.py @@ -23,53 +23,55 @@ stable end user API see :mod:`pyignite.client` module. """ +# flake8: noqa + from .affinity import ( - cache_get_node_partitions, + cache_get_node_partitions, cache_get_node_partitions_async, ) from .cache_config import ( - cache_create, - cache_get_names, - cache_get_or_create, - cache_destroy, - cache_get_configuration, - cache_create_with_config, - cache_get_or_create_with_config, + cache_create, cache_create_async, + cache_get_names, cache_get_names_async, + cache_get_or_create, cache_get_or_create_async, + cache_destroy, cache_destroy_async, + cache_get_configuration, cache_get_configuration_async, + cache_create_with_config, cache_create_with_config_async, + cache_get_or_create_with_config, cache_get_or_create_with_config_async, ) from .key_value import ( - cache_get, - cache_put, - cache_get_all, - cache_put_all, - cache_contains_key, - cache_contains_keys, - cache_get_and_put, - cache_get_and_replace, - cache_get_and_remove, - cache_put_if_absent, - cache_get_and_put_if_absent, - cache_replace, - cache_replace_if_equals, - cache_clear, - cache_clear_key, - cache_clear_keys, - cache_remove_key, - cache_remove_if_equals, - cache_remove_keys, - cache_remove_all, - cache_get_size, - cache_local_peek, + cache_get, cache_get_async, + cache_put, cache_put_async, + cache_get_all, cache_get_all_async, + cache_put_all, cache_put_all_async, + cache_contains_key, cache_contains_key_async, + cache_contains_keys, cache_contains_keys_async, + cache_get_and_put, cache_get_and_put_async, + cache_get_and_replace, cache_get_and_replace_async, + cache_get_and_remove, cache_get_and_remove_async, + cache_put_if_absent, cache_put_if_absent_async, + cache_get_and_put_if_absent, cache_get_and_put_if_absent_async, + cache_replace, cache_replace_async, + cache_replace_if_equals, cache_replace_if_equals_async, + cache_clear, cache_clear_async, + cache_clear_key, cache_clear_key_async, + cache_clear_keys, cache_clear_keys_async, + cache_remove_key, cache_remove_key_async, + cache_remove_if_equals, cache_remove_if_equals_async, + cache_remove_keys, cache_remove_keys_async, + cache_remove_all, cache_remove_all_async, + cache_get_size, cache_get_size_async, + cache_local_peek, cache_local_peek_async, ) from .sql import ( - scan, - scan_cursor_get_page, + scan, scan_async, + scan_cursor_get_page, scan_cursor_get_page_async, sql, sql_cursor_get_page, - sql_fields, - sql_fields_cursor_get_page, - resource_close, + sql_fields, sql_fields_async, + sql_fields_cursor_get_page, sql_fields_cursor_get_page_async, + resource_close, resource_close_async ) from .binary import ( - get_binary_type, - put_binary_type, + get_binary_type, get_binary_type_async, + put_binary_type, put_binary_type_async ) from .result import APIResult diff --git a/pyignite/api/affinity.py b/pyignite/api/affinity.py index 7d09517..ddf1e7a 100644 --- a/pyignite/api/affinity.py +++ b/pyignite/api/affinity.py @@ -15,9 +15,10 @@ from typing import Iterable, Union +from pyignite.connection import AioConnection, Connection from pyignite.datatypes import Bool, Int, Long, UUIDObject from pyignite.datatypes.internal import StructArray, Conditional, Struct -from pyignite.queries import Query +from pyignite.queries import Query, query_perform from pyignite.queries.op_codes import OP_CACHE_PARTITIONS from pyignite.utils import is_iterable from .result import APIResult @@ -67,10 +68,7 @@ ]) -def cache_get_node_partitions( - conn: 'Connection', caches: Union[int, Iterable[int]], - query_id: int = None, -) -> APIResult: +def cache_get_node_partitions(conn: 'Connection', caches: Union[int, Iterable[int]], query_id: int = None) -> APIResult: """ Gets partition mapping for an Ignite cache or a number of caches. See “IEP-23: Best Effort Affinity for thin clients”. @@ -82,6 +80,62 @@ def cache_get_node_partitions( is generated, :return: API result data object. """ + return __cache_get_node_partitions(conn, caches, query_id) + + +async def cache_get_node_partitions_async(conn: 'AioConnection', caches: Union[int, Iterable[int]], + query_id: int = None) -> APIResult: + """ + Async version of cache_get_node_partitions. + """ + return await __cache_get_node_partitions(conn, caches, query_id) + + +def __post_process_partitions(result): + if result.status == 0: + # tidying up the result + value = { + 'version': ( + result.value['version_major'], + result.value['version_minor'] + ), + 'partition_mapping': {}, + } + for partition_map in result.value['partition_mapping']: + is_applicable = partition_map['is_applicable'] + + node_mapping = None + if is_applicable: + node_mapping = { + p['node_uuid']: set(x['partition_id'] for x in p['node_partitions']) + for p in partition_map['node_mapping'] + } + + for cache_info in partition_map['cache_mapping']: + cache_id = cache_info['cache_id'] + + cache_partition_mapping = { + 'is_applicable': is_applicable, + } + + parts = 0 + if is_applicable: + cache_partition_mapping['cache_config'] = { + a['key_type_id']: a['affinity_key_field_id'] + for a in cache_info['cache_config'] + } + cache_partition_mapping['node_mapping'] = node_mapping + + parts = sum(len(p) for p in cache_partition_mapping['node_mapping'].values()) + + cache_partition_mapping['number_of_partitions'] = parts + + value['partition_mapping'][cache_id] = cache_partition_mapping + result.value = value + return result + + +def __cache_get_node_partitions(conn, caches, query_id): query_struct = Query( OP_CACHE_PARTITIONS, [ @@ -92,7 +146,8 @@ def cache_get_node_partitions( if not is_iterable(caches): caches = [caches] - result = query_struct.perform( + return query_perform( + query_struct, conn, query_params={ 'cache_ids': [{'cache_id': cache} for cache in caches], @@ -102,36 +157,5 @@ def cache_get_node_partitions( ('version_minor', Int), ('partition_mapping', partition_mapping), ], + post_process_fun=__post_process_partitions ) - if result.status == 0: - # tidying up the result - value = { - 'version': ( - result.value['version_major'], - result.value['version_minor'] - ), - 'partition_mapping': [], - } - for i, partition_map in enumerate(result.value['partition_mapping']): - cache_id = partition_map['cache_mapping'][0]['cache_id'] - value['partition_mapping'].insert( - i, - { - 'cache_id': cache_id, - 'is_applicable': partition_map['is_applicable'], - } - ) - if partition_map['is_applicable']: - value['partition_mapping'][i]['cache_config'] = { - a['key_type_id']: a['affinity_key_field_id'] - for a in partition_map['cache_mapping'][0]['cache_config'] - } - value['partition_mapping'][i]['node_mapping'] = { - p['node_uuid']: [ - x['partition_id'] for x in p['node_partitions'] - ] - for p in partition_map['node_mapping'] - } - result.value = value - - return result diff --git a/pyignite/api/binary.py b/pyignite/api/binary.py index 87a5232..345e8e8 100644 --- a/pyignite/api/binary.py +++ b/pyignite/api/binary.py @@ -15,17 +15,15 @@ from typing import Union -from pyignite.constants import * -from pyignite.datatypes.binary import ( - body_struct, enum_struct, schema_struct, binary_fields_struct, -) +from pyignite.connection import Connection, AioConnection +from pyignite.constants import PROTOCOL_BYTE_ORDER +from pyignite.datatypes.binary import enum_struct, schema_struct, binary_fields_struct from pyignite.datatypes import String, Int, Bool -from pyignite.queries import Query -from pyignite.queries.op_codes import * +from pyignite.queries import Query, query_perform +from pyignite.queries.op_codes import OP_GET_BINARY_TYPE, OP_PUT_BINARY_TYPE from pyignite.utils import entity_id, schema_id from .result import APIResult -from ..stream import BinaryStream, READ_BACKWARD -from ..queries.response import Response +from ..queries.response import BinaryTypeResponse def get_binary_type(conn: 'Connection', binary_type: Union[str, int], query_id=None) -> APIResult: @@ -39,75 +37,33 @@ def get_binary_type(conn: 'Connection', binary_type: Union[str, int], query_id=N is generated, :return: API result data object. """ + return __get_binary_type(conn, binary_type, query_id) + +async def get_binary_type_async(conn: 'AioConnection', binary_type: Union[str, int], query_id=None) -> APIResult: + """ + Async version of get_binary_type. + """ + return await __get_binary_type(conn, binary_type, query_id) + + +def __get_binary_type(conn, binary_type, query_id): query_struct = Query( OP_GET_BINARY_TYPE, [ ('type_id', Int), ], query_id=query_id, + response_type=BinaryTypeResponse ) - with BinaryStream(conn) as stream: - query_struct.from_python(stream, { - 'type_id': entity_id(binary_type), - }) - conn.send(stream.getbuffer()) - - response_head_struct = Response(protocol_version=conn.get_protocol_version(), - following=[('type_exists', Bool)]) - - with BinaryStream(conn, conn.recv()) as stream: - init_pos = stream.tell() - response_head_type = response_head_struct.parse(stream) - response_head = stream.read_ctype(response_head_type, direction=READ_BACKWARD) - - response_parts = [] - if response_head.type_exists: - resp_body_type = body_struct.parse(stream) - response_parts.append(('body', resp_body_type)) - resp_body = stream.read_ctype(resp_body_type, direction=READ_BACKWARD) - if resp_body.is_enum: - resp_enum = enum_struct.parse(stream) - response_parts.append(('enums', resp_enum)) - - resp_schema_type = schema_struct.parse(stream) - response_parts.append(('schema', resp_schema_type)) - - response_class = type( - 'GetBinaryTypeResponse', - (response_head_type,), - { - '_pack_': 1, - '_fields_': response_parts, - } - ) - response = stream.read_ctype(response_class, position=init_pos) + return query_perform(query_struct, conn, query_params={ + 'type_id': entity_id(binary_type), + }) - result = APIResult(response) - if result.status != 0: - return result - result.value = { - 'type_exists': Bool.to_python(response.type_exists) - } - if hasattr(response, 'body'): - result.value.update(body_struct.to_python(response.body)) - if hasattr(response, 'enums'): - result.value['enums'] = enum_struct.to_python(response.enums) - if hasattr(response, 'schema'): - result.value['schema'] = { - x['schema_id']: [ - z['schema_field_id'] for z in x['schema_fields'] - ] - for x in schema_struct.to_python(response.schema) - } - return result - - -def put_binary_type( - connection: 'Connection', type_name: str, affinity_key_field: str=None, - is_enum=False, schema: dict=None, query_id=None, -) -> APIResult: + +def put_binary_type(connection: 'Connection', type_name: str, affinity_key_field: str = None, + is_enum=False, schema: dict = None, query_id=None) -> APIResult: """ Registers binary type information in cluster. @@ -125,6 +81,29 @@ def put_binary_type( is generated, :return: API result data object. """ + return __put_binary_type(connection, type_name, affinity_key_field, is_enum, schema, query_id) + + +async def put_binary_type_async(connection: 'AioConnection', type_name: str, affinity_key_field: str = None, + is_enum=False, schema: dict = None, query_id=None) -> APIResult: + """ + Async version of put_binary_type. + """ + return await __put_binary_type(connection, type_name, affinity_key_field, is_enum, schema, query_id) + + +def __post_process_put_binary(type_id): + def internal(result): + if result.status == 0: + result.value = { + 'type_id': type_id, + 'schema_id': schema_id, + } + return result + return internal + + +def __put_binary_type(connection, type_name, affinity_key_field, is_enum, schema, query_id): # prepare data if schema is None: schema = {} @@ -195,10 +174,5 @@ def put_binary_type( ], query_id=query_id, ) - result = query_struct.perform(connection, query_params=data) - if result.status == 0: - result.value = { - 'type_id': type_id, - 'schema_id': schema_id, - } - return result + return query_perform(query_struct, connection, query_params=data, + post_process_fun=__post_process_put_binary(type_id)) diff --git a/pyignite/api/cache_config.py b/pyignite/api/cache_config.py index cfea416..0adb549 100644 --- a/pyignite/api/cache_config.py +++ b/pyignite/api/cache_config.py @@ -25,15 +25,19 @@ from typing import Union +from pyignite.connection import Connection, AioConnection from pyignite.datatypes.cache_config import cache_config_struct from pyignite.datatypes.cache_properties import prop_map -from pyignite.datatypes import ( - Int, Byte, prop_codes, Short, String, StringArray, +from pyignite.datatypes import Int, Byte, prop_codes, Short, String, StringArray +from pyignite.queries import Query, ConfigQuery, query_perform +from pyignite.queries.op_codes import ( + OP_CACHE_GET_CONFIGURATION, OP_CACHE_CREATE_WITH_NAME, OP_CACHE_GET_OR_CREATE_WITH_NAME, OP_CACHE_DESTROY, + OP_CACHE_GET_NAMES, OP_CACHE_CREATE_WITH_CONFIGURATION, OP_CACHE_GET_OR_CREATE_WITH_CONFIGURATION ) -from pyignite.queries import Query, ConfigQuery -from pyignite.queries.op_codes import * from pyignite.utils import cache_id +from .result import APIResult + def compact_cache_config(cache_config: dict) -> dict: """ @@ -48,14 +52,13 @@ def compact_cache_config(cache_config: dict) -> dict: for k, v in cache_config.items(): if k == 'length': continue - prop_code = getattr(prop_codes, 'PROP_{}'.format(k.upper())) + prop_code = getattr(prop_codes, f'PROP_{k.upper()}') result[prop_code] = v return result -def cache_get_configuration( - connection: 'Connection', cache: Union[str, int], flags: int=0, query_id=None, -) -> 'APIResult': +def cache_get_configuration(connection: 'Connection', cache: Union[str, int], + flags: int = 0, query_id=None) -> 'APIResult': """ Gets configuration for the given cache. @@ -68,7 +71,24 @@ def cache_get_configuration( :return: API result data object. Result value is OrderedDict with the cache configuration parameters. """ + return __cache_get_configuration(connection, cache, flags, query_id) + + +async def cache_get_configuration_async(connection: 'AioConnection', cache: Union[str, int], + flags: int = 0, query_id=None) -> 'APIResult': + """ + Async version of cache_get_configuration. + """ + return await __cache_get_configuration(connection, cache, flags, query_id) + + +def __post_process_cache_config(result): + if result.status == 0: + result.value = compact_cache_config(result.value['cache_config']) + return result + +def __cache_get_configuration(connection, cache, flags, query_id): query_struct = Query( OP_CACHE_GET_CONFIGURATION, [ @@ -77,24 +97,19 @@ def cache_get_configuration( ], query_id=query_id, ) - result = query_struct.perform( - connection, - query_params={ - 'hash_code': cache_id(cache), - 'flags': flags, - }, - response_config=[ - ('cache_config', cache_config_struct), - ], - ) - if result.status == 0: - result.value = compact_cache_config(result.value['cache_config']) - return result - - -def cache_create( - connection: 'Connection', name: str, query_id=None, -) -> 'APIResult': + return query_perform(query_struct, connection, + query_params={ + 'hash_code': cache_id(cache), + 'flags': flags + }, + response_config=[ + ('cache_config', cache_config_struct) + ], + post_process_fun=__post_process_cache_config + ) + + +def cache_create(connection: 'Connection', name: str, query_id=None) -> 'APIResult': """ Creates a cache with a given name. Returns error if a cache with specified name already exists. @@ -108,24 +123,18 @@ def cache_create( created successfully, non-zero status and an error description otherwise. """ - query_struct = Query( - OP_CACHE_CREATE_WITH_NAME, - [ - ('cache_name', String), - ], - query_id=query_id, - ) - return query_struct.perform( - connection, - query_params={ - 'cache_name': name, - }, - ) + return __cache_create_with_name(OP_CACHE_CREATE_WITH_NAME, connection, name, query_id) -def cache_get_or_create( - connection: 'Connection', name: str, query_id=None, -) -> 'APIResult': +async def cache_create_async(connection: 'AioConnection', name: str, query_id=None) -> 'APIResult': + """ + Async version of cache_create. + """ + + return await __cache_create_with_name(OP_CACHE_CREATE_WITH_NAME, connection, name, query_id) + + +def cache_get_or_create(connection: 'Connection', name: str, query_id=None) -> 'APIResult': """ Creates a cache with a given name. Does nothing if the cache exists. @@ -138,24 +147,22 @@ def cache_get_or_create( created successfully, non-zero status and an error description otherwise. """ - query_struct = Query( - OP_CACHE_GET_OR_CREATE_WITH_NAME, - [ - ('cache_name', String), - ], - query_id=query_id, - ) - return query_struct.perform( - connection, - query_params={ - 'cache_name': name, - }, - ) + return __cache_create_with_name(OP_CACHE_GET_OR_CREATE_WITH_NAME, connection, name, query_id) + + +async def cache_get_or_create_async(connection: 'AioConnection', name: str, query_id=None) -> 'APIResult': + """ + Async version of cache_get_or_create. + """ + return await __cache_create_with_name(OP_CACHE_GET_OR_CREATE_WITH_NAME, connection, name, query_id) + +def __cache_create_with_name(op_code, conn, name, query_id): + query_struct = Query(op_code, [('cache_name', String)], query_id=query_id) + return query_perform(query_struct, conn, query_params={'cache_name': name}) -def cache_destroy( - connection: 'Connection', cache: Union[str, int], query_id=None, -) -> 'APIResult': + +def cache_destroy(connection: 'Connection', cache: Union[str, int], query_id=None) -> 'APIResult': """ Destroys cache with a given name. @@ -166,19 +173,20 @@ def cache_destroy( is generated, :return: API result data object. """ + return __cache_destroy(connection, cache, query_id) - query_struct = Query( - OP_CACHE_DESTROY,[ - ('hash_code', Int), - ], - query_id=query_id, - ) - return query_struct.perform( - connection, - query_params={ - 'hash_code': cache_id(cache), - }, - ) + +async def cache_destroy_async(connection: 'AioConnection', cache: Union[str, int], query_id=None) -> 'APIResult': + """ + Async version of cache_destroy. + """ + return await __cache_destroy(connection, cache, query_id) + + +def __cache_destroy(connection, cache, query_id): + query_struct = Query(OP_CACHE_DESTROY, [('hash_code', Int)], query_id=query_id) + + return query_perform(query_struct, connection, query_params={'hash_code': cache_id(cache)}) def cache_get_names(connection: 'Connection', query_id=None) -> 'APIResult': @@ -193,21 +201,30 @@ def cache_get_names(connection: 'Connection', query_id=None) -> 'APIResult': names, non-zero status and an error description otherwise. """ - query_struct = Query(OP_CACHE_GET_NAMES, query_id=query_id) - result = query_struct.perform( - connection, - response_config=[ - ('cache_names', StringArray), - ], - ) + return __cache_get_names(connection, query_id) + + +async def cache_get_names_async(connection: 'AioConnection', query_id=None) -> 'APIResult': + """ + Async version of cache_get_names. + """ + return await __cache_get_names(connection, query_id) + + +def __post_process_cache_names(result): if result.status == 0: result.value = result.value['cache_names'] return result -def cache_create_with_config( - connection: 'Connection', cache_props: dict, query_id=None, -) -> 'APIResult': +def __cache_get_names(connection, query_id): + query_struct = Query(OP_CACHE_GET_NAMES, query_id=query_id) + return query_perform(query_struct, connection, + response_config=[('cache_names', StringArray)], + post_process_fun=__post_process_cache_names) + + +def cache_create_with_config(connection: 'Connection', cache_props: dict, query_id=None) -> 'APIResult': """ Creates cache with provided configuration. An error is returned if the name is already in use. @@ -222,29 +239,17 @@ def cache_create_with_config( :return: API result data object. Contains zero status if cache was created, non-zero status and an error description otherwise. """ + return __cache_create_with_config(OP_CACHE_CREATE_WITH_CONFIGURATION, connection, cache_props, query_id) - prop_types = {} - prop_values = {} - for i, prop_item in enumerate(cache_props.items()): - prop_code, prop_value = prop_item - prop_name = 'property_{}'.format(i) - prop_types[prop_name] = prop_map(prop_code) - prop_values[prop_name] = prop_value - prop_values['param_count'] = len(cache_props) - query_struct = ConfigQuery( - OP_CACHE_CREATE_WITH_CONFIGURATION, - [ - ('param_count', Short), - ] + list(prop_types.items()), - query_id=query_id, - ) - return query_struct.perform(connection, query_params=prop_values) +async def cache_create_with_config_async(connection: 'AioConnection', cache_props: dict, query_id=None) -> 'APIResult': + """ + Async version of cache_create_with_config. + """ + return await __cache_create_with_config(OP_CACHE_CREATE_WITH_CONFIGURATION, connection, cache_props, query_id) -def cache_get_or_create_with_config( - connection: 'Connection', cache_props: dict, query_id=None, -) -> 'APIResult': +def cache_get_or_create_with_config(connection: 'Connection', cache_props: dict, query_id=None) -> 'APIResult': """ Creates cache with provided configuration. Does nothing if the name is already in use. @@ -259,9 +264,20 @@ def cache_get_or_create_with_config( :return: API result data object. Contains zero status if cache was created, non-zero status and an error description otherwise. """ + return __cache_create_with_config(OP_CACHE_GET_OR_CREATE_WITH_CONFIGURATION, connection, cache_props, query_id) + + +async def cache_get_or_create_with_config_async(connection: 'AioConnection', cache_props: dict, + query_id=None) -> 'APIResult': + """ + Async version of cache_get_or_create_with_config. + """ + return await __cache_create_with_config(OP_CACHE_GET_OR_CREATE_WITH_CONFIGURATION, connection, cache_props, + query_id) + - prop_types = {} - prop_values = {} +def __cache_create_with_config(op_code, connection, cache_props, query_id): + prop_types, prop_values = {}, {} for i, prop_item in enumerate(cache_props.items()): prop_code, prop_value = prop_item prop_name = 'property_{}'.format(i) @@ -269,11 +285,6 @@ def cache_get_or_create_with_config( prop_values[prop_name] = prop_value prop_values['param_count'] = len(cache_props) - query_struct = ConfigQuery( - OP_CACHE_GET_OR_CREATE_WITH_CONFIGURATION, - [ - ('param_count', Short), - ] + list(prop_types.items()), - query_id=query_id, - ) - return query_struct.perform(connection, query_params=prop_values) + following = [('param_count', Short)] + list(prop_types.items()) + query_struct = ConfigQuery(op_code, following, query_id=query_id) + return query_perform(query_struct, connection, query_params=prop_values) diff --git a/pyignite/api/key_value.py b/pyignite/api/key_value.py index 25601e9..6d5663c 100644 --- a/pyignite/api/key_value.py +++ b/pyignite/api/key_value.py @@ -15,20 +15,26 @@ from typing import Any, Iterable, Optional, Union -from pyignite.queries.op_codes import * -from pyignite.datatypes import ( - Map, Bool, Byte, Int, Long, AnyDataArray, AnyDataObject, +from pyignite.connection import AioConnection, Connection +from pyignite.queries.op_codes import ( + OP_CACHE_PUT, OP_CACHE_GET, OP_CACHE_GET_ALL, OP_CACHE_PUT_ALL, OP_CACHE_CONTAINS_KEY, OP_CACHE_CONTAINS_KEYS, + OP_CACHE_GET_AND_PUT, OP_CACHE_GET_AND_REPLACE, OP_CACHE_GET_AND_REMOVE, OP_CACHE_PUT_IF_ABSENT, + OP_CACHE_GET_AND_PUT_IF_ABSENT, OP_CACHE_REPLACE, OP_CACHE_REPLACE_IF_EQUALS, OP_CACHE_CLEAR, OP_CACHE_CLEAR_KEY, + OP_CACHE_CLEAR_KEYS, OP_CACHE_REMOVE_KEY, OP_CACHE_REMOVE_IF_EQUALS, OP_CACHE_REMOVE_KEYS, OP_CACHE_REMOVE_ALL, + OP_CACHE_GET_SIZE, OP_CACHE_LOCAL_PEEK ) +from pyignite.datatypes import Map, Bool, Byte, Int, Long, AnyDataArray, AnyDataObject +from pyignite.datatypes.base import IgniteDataType from pyignite.datatypes.key_value import PeekModes -from pyignite.queries import Query +from pyignite.queries import Query, query_perform from pyignite.utils import cache_id +from .result import APIResult -def cache_put( - connection: 'Connection', cache: Union[str, int], key: Any, value: Any, - key_hint: 'IgniteDataType' = None, value_hint: 'IgniteDataType' = None, - binary: bool = False, query_id: Optional[int] = None, -) -> 'APIResult': + +def cache_put(connection: 'Connection', cache: Union[str, int], key: Any, value: Any, + key_hint: 'IgniteDataType' = None, value_hint: 'IgniteDataType' = None, binary: bool = False, + query_id: Optional[int] = None) -> 'APIResult': """ Puts a value with a given key to cache (overwriting existing value if any). @@ -48,7 +54,19 @@ def cache_put( :return: API result data object. Contains zero status if a value is written, non-zero status and an error description otherwise. """ + return __cache_put(connection, cache, key, value, key_hint, value_hint, binary, query_id) + + +async def cache_put_async(connection: 'AioConnection', cache: Union[str, int], key: Any, value: Any, + key_hint: 'IgniteDataType' = None, value_hint: 'IgniteDataType' = None, binary: bool = False, + query_id: Optional[int] = None) -> 'APIResult': + """ + Async version of cache_put + """ + return await __cache_put(connection, cache, key, value, key_hint, value_hint, binary, query_id) + +def __cache_put(connection, cache, key, value, key_hint, value_hint, binary, query_id): query_struct = Query( OP_CACHE_PUT, [ @@ -59,19 +77,19 @@ def cache_put( ], query_id=query_id, ) - return query_struct.perform(connection, { - 'hash_code': cache_id(cache), - 'flag': 1 if binary else 0, - 'key': key, - 'value': value, - }) + return query_perform( + query_struct, connection, + query_params={ + 'hash_code': cache_id(cache), + 'flag': 1 if binary else 0, + 'key': key, + 'value': value + } + ) -def cache_get( - connection: 'Connection', cache: Union[str, int], key: Any, - key_hint: 'IgniteDataType' = None, - binary: bool = False, query_id: Optional[int] = None, -) -> 'APIResult': +def cache_get(connection: 'Connection', cache: Union[str, int], key: Any, key_hint: 'IgniteDataType' = None, + binary: bool = False, query_id: Optional[int] = None) -> 'APIResult': """ Retrieves a value from cache by key. @@ -88,7 +106,19 @@ def cache_get( :return: API result data object. Contains zero status and a value retrieved on success, non-zero status and an error description on failure. """ + return __cache_get(connection, cache, key, key_hint, binary, query_id) + + +async def cache_get_async(connection: 'AioConnection', cache: Union[str, int], key: Any, + key_hint: 'IgniteDataType' = None, binary: bool = False, + query_id: Optional[int] = None) -> 'APIResult': + """ + Async version of cache_get + """ + return await __cache_get(connection, cache, key, key_hint, binary, query_id) + +def __cache_get(connection, cache, key, key_hint, binary, query_id): query_struct = Query( OP_CACHE_GET, [ @@ -98,27 +128,22 @@ def cache_get( ], query_id=query_id, ) - result = query_struct.perform( - connection, + return query_perform( + query_struct, connection, query_params={ 'hash_code': cache_id(cache), 'flag': 1 if binary else 0, 'key': key, }, response_config=[ - ('value', AnyDataObject), + ('value', AnyDataObject), ], + post_process_fun=__post_process_value_by_key('value') ) - if result.status != 0: - return result - result.value = result.value['value'] - return result -def cache_get_all( - connection: 'Connection', cache: Union[str, int], keys: Iterable, - binary: bool = False, query_id: Optional[int] = None, -) -> 'APIResult': +def cache_get_all(connection: 'Connection', cache: Union[str, int], keys: Iterable, binary: bool = False, + query_id: Optional[int] = None) -> 'APIResult': """ Retrieves multiple key-value pairs from cache. @@ -134,7 +159,18 @@ def cache_get_all( retrieved key-value pairs, non-zero status and an error description on failure. """ + return __cache_get_all(connection, cache, keys, binary, query_id) + + +async def cache_get_all_async(connection: 'AioConnection', cache: Union[str, int], keys: Iterable, binary: bool = False, + query_id: Optional[int] = None) -> 'APIResult': + """ + Async version of cache_get_all. + """ + return await __cache_get_all(connection, cache, keys, binary, query_id) + +def __cache_get_all(connection, cache, keys, binary, query_id): query_struct = Query( OP_CACHE_GET_ALL, [ @@ -144,8 +180,8 @@ def cache_get_all( ], query_id=query_id, ) - result = query_struct.perform( - connection, + return query_perform( + query_struct, connection, query_params={ 'hash_code': cache_id(cache), 'flag': 1 if binary else 0, @@ -154,16 +190,12 @@ def cache_get_all( response_config=[ ('data', Map), ], + post_process_fun=__post_process_value_by_key('data') ) - if result.status == 0: - result.value = dict(result.value)['data'] - return result -def cache_put_all( - connection: 'Connection', cache: Union[str, int], pairs: dict, - binary: bool = False, query_id: Optional[int] = None, -) -> 'APIResult': +def cache_put_all(connection: 'Connection', cache: Union[str, int], pairs: dict, binary: bool = False, + query_id: Optional[int] = None) -> 'APIResult': """ Puts multiple key-value pairs to cache (overwriting existing associations if any). @@ -181,7 +213,18 @@ def cache_put_all( :return: API result data object. Contains zero status if key-value pairs are written, non-zero status and an error description otherwise. """ + return __cache_put_all(connection, cache, pairs, binary, query_id) + +async def cache_put_all_async(connection: 'AioConnection', cache: Union[str, int], pairs: dict, binary: bool = False, + query_id: Optional[int] = None) -> 'APIResult': + """ + Async version of cache_put_all. + """ + return await __cache_put_all(connection, cache, pairs, binary, query_id) + + +def __cache_put_all(connection, cache, pairs, binary, query_id): query_struct = Query( OP_CACHE_PUT_ALL, [ @@ -191,8 +234,8 @@ def cache_put_all( ], query_id=query_id, ) - return query_struct.perform( - connection, + return query_perform( + query_struct, connection, query_params={ 'hash_code': cache_id(cache), 'flag': 1 if binary else 0, @@ -201,11 +244,8 @@ def cache_put_all( ) -def cache_contains_key( - connection: 'Connection', cache: Union[str, int], key: Any, - key_hint: 'IgniteDataType' = None, - binary: bool = False, query_id: Optional[int] = None, -) -> 'APIResult': +def cache_contains_key(connection: 'Connection', cache: Union[str, int], key: Any, key_hint: 'IgniteDataType' = None, + binary: bool = False, query_id: Optional[int] = None) -> 'APIResult': """ Returns a value indicating whether given key is present in cache. @@ -223,7 +263,19 @@ def cache_contains_key( retrieved on success: `True` when key is present, `False` otherwise, non-zero status and an error description on failure. """ + return __cache_contains_key(connection, cache, key, key_hint, binary, query_id) + + +async def cache_contains_key_async(connection: 'AioConnection', cache: Union[str, int], key: Any, + key_hint: 'IgniteDataType' = None, binary: bool = False, + query_id: Optional[int] = None) -> 'APIResult': + """ + Async version of cache_contains_key. + """ + return await __cache_contains_key(connection, cache, key, key_hint, binary, query_id) + +def __cache_contains_key(connection, cache, key, key_hint, binary, query_id): query_struct = Query( OP_CACHE_CONTAINS_KEY, [ @@ -233,9 +285,9 @@ def cache_contains_key( ], query_id=query_id, ) - result = query_struct.perform( - connection, - query_params={ + return query_perform( + query_struct, connection, + query_params={ 'hash_code': cache_id(cache), 'flag': 1 if binary else 0, 'key': key, @@ -243,16 +295,12 @@ def cache_contains_key( response_config=[ ('value', Bool), ], + post_process_fun=__post_process_value_by_key('value') ) - if result.status == 0: - result.value = result.value['value'] - return result -def cache_contains_keys( - connection: 'Connection', cache: Union[str, int], keys: Iterable, - binary: bool = False, query_id: Optional[int] = None, -) -> 'APIResult': +def cache_contains_keys(connection: 'Connection', cache: Union[str, int], keys: Iterable, binary: bool = False, + query_id: Optional[int] = None) -> 'APIResult': """ Returns a value indicating whether all given keys are present in cache. @@ -268,7 +316,18 @@ def cache_contains_keys( retrieved on success: `True` when all keys are present, `False` otherwise, non-zero status and an error description on failure. """ + return __cache_contains_keys(connection, cache, keys, binary, query_id) + + +async def cache_contains_keys_async(connection: 'AioConnection', cache: Union[str, int], keys: Iterable, + binary: bool = False, query_id: Optional[int] = None) -> 'APIResult': + """ + Async version of cache_contains_keys. + """ + return await __cache_contains_keys(connection, cache, keys, binary, query_id) + +def __cache_contains_keys(connection, cache, keys, binary, query_id): query_struct = Query( OP_CACHE_CONTAINS_KEYS, [ @@ -278,8 +337,8 @@ def cache_contains_keys( ], query_id=query_id, ) - result = query_struct.perform( - connection, + return query_perform( + query_struct, connection, query_params={ 'hash_code': cache_id(cache), 'flag': 1 if binary else 0, @@ -288,17 +347,13 @@ def cache_contains_keys( response_config=[ ('value', Bool), ], + post_process_fun=__post_process_value_by_key('value') ) - if result.status == 0: - result.value = result.value['value'] - return result -def cache_get_and_put( - connection: 'Connection', cache: Union[str, int], key: Any, value: Any, - key_hint: 'IgniteDataType' = None, value_hint: 'IgniteDataType' = None, - binary: bool = False, query_id: Optional[int] = None, -) -> 'APIResult': +def cache_get_and_put(connection: 'Connection', cache: Union[str, int], key: Any, value: Any, + key_hint: 'IgniteDataType' = None, value_hint: 'IgniteDataType' = None, binary: bool = False, + query_id: Optional[int] = None) -> 'APIResult': """ Puts a value with a given key to cache, and returns the previous value for that key, or null value if there was not such key. @@ -320,7 +375,19 @@ def cache_get_and_put( or None if a value is written, non-zero status and an error description in case of error. """ + return __cache_get_and_put(connection, cache, key, value, key_hint, value_hint, binary, query_id) + +async def cache_get_and_put_async(connection: 'AioConnection', cache: Union[str, int], key: Any, value: Any, + key_hint: 'IgniteDataType' = None, value_hint: 'IgniteDataType' = None, + binary: bool = False, query_id: Optional[int] = None) -> 'APIResult': + """ + Async version of cache_get_and_put. + """ + return await __cache_get_and_put(connection, cache, key, value, key_hint, value_hint, binary, query_id) + + +def __cache_get_and_put(connection, cache, key, value, key_hint, value_hint, binary, query_id): query_struct = Query( OP_CACHE_GET_AND_PUT, [ @@ -331,8 +398,8 @@ def cache_get_and_put( ], query_id=query_id, ) - result = query_struct.perform( - connection, + return query_perform( + query_struct, connection, query_params={ 'hash_code': cache_id(cache), 'flag': 1 if binary else 0, @@ -342,17 +409,13 @@ def cache_get_and_put( response_config=[ ('value', AnyDataObject), ], + post_process_fun=__post_process_value_by_key('value') ) - if result.status == 0: - result.value = result.value['value'] - return result -def cache_get_and_replace( - connection: 'Connection', cache: Union[str, int], key: Any, value: Any, - key_hint: 'IgniteDataType' = None, value_hint: 'IgniteDataType' = None, - binary: bool = False, query_id: Optional[int] = None, -) -> 'APIResult': +def cache_get_and_replace(connection: 'Connection', cache: Union[str, int], key: Any, value: Any, + key_hint: 'IgniteDataType' = None, value_hint: 'IgniteDataType' = None, binary: bool = False, + query_id: Optional[int] = None) -> 'APIResult': """ Puts a value with a given key to cache, returning previous value for that key, if and only if there is a value currently mapped @@ -374,7 +437,19 @@ def cache_get_and_replace( :return: API result data object. Contains zero status and an old value or None on success, non-zero status and an error description otherwise. """ + return __cache_get_and_replace(connection, cache, key, key_hint, value, value_hint, binary, query_id) + + +async def cache_get_and_replace_async(connection: 'AioConnection', cache: Union[str, int], key: Any, value: Any, + key_hint: 'IgniteDataType' = None, value_hint: 'IgniteDataType' = None, + binary: bool = False, query_id: Optional[int] = None) -> 'APIResult': + """ + Async version of cache_get_and_replace. + """ + return await __cache_get_and_replace(connection, cache, key, key_hint, value, value_hint, binary, query_id) + +def __cache_get_and_replace(connection, cache, key, key_hint, value, value_hint, binary, query_id): query_struct = Query( OP_CACHE_GET_AND_REPLACE, [ ('hash_code', Int), @@ -384,8 +459,8 @@ def cache_get_and_replace( ], query_id=query_id, ) - result = query_struct.perform( - connection, + return query_perform( + query_struct, connection, query_params={ 'hash_code': cache_id(cache), 'flag': 1 if binary else 0, @@ -395,17 +470,12 @@ def cache_get_and_replace( response_config=[ ('value', AnyDataObject), ], + post_process_fun=__post_process_value_by_key('value') ) - if result.status == 0: - result.value = result.value['value'] - return result -def cache_get_and_remove( - connection: 'Connection', cache: Union[str, int], key: Any, - key_hint: 'IgniteDataType' = None, - binary: bool = False, query_id: Optional[int] = None, -) -> 'APIResult': +def cache_get_and_remove(connection: 'Connection', cache: Union[str, int], key: Any, key_hint: 'IgniteDataType' = None, + binary: bool = False, query_id: Optional[int] = None) -> 'APIResult': """ Removes the cache entry with specified key, returning the value. @@ -422,7 +492,16 @@ def cache_get_and_remove( :return: API result data object. Contains zero status and an old value or None, non-zero status and an error description otherwise. """ + return __cache_get_and_remove(connection, cache, key, key_hint, binary, query_id) + +async def cache_get_and_remove_async(connection: 'AioConnection', cache: Union[str, int], key: Any, + key_hint: 'IgniteDataType' = None, binary: bool = False, + query_id: Optional[int] = None) -> 'APIResult': + return await __cache_get_and_remove(connection, cache, key, key_hint, binary, query_id) + + +def __cache_get_and_remove(connection, cache, key, key_hint, binary, query_id): query_struct = Query( OP_CACHE_GET_AND_REMOVE, [ ('hash_code', Int), @@ -431,8 +510,8 @@ def cache_get_and_remove( ], query_id=query_id, ) - result = query_struct.perform( - connection, + return query_perform( + query_struct, connection, query_params={ 'hash_code': cache_id(cache), 'flag': 1 if binary else 0, @@ -441,17 +520,13 @@ def cache_get_and_remove( response_config=[ ('value', AnyDataObject), ], + post_process_fun=__post_process_value_by_key('value') ) - if result.status == 0: - result.value = result.value['value'] - return result -def cache_put_if_absent( - connection: 'Connection', cache: Union[str, int], key: Any, value: Any, - key_hint: 'IgniteDataType' = None, value_hint: 'IgniteDataType' = None, - binary: bool = False, query_id: Optional[int] = None, -) -> 'APIResult': +def cache_put_if_absent(connection: 'Connection', cache: Union[str, int], key: Any, value: Any, + key_hint: 'IgniteDataType' = None, value_hint: 'IgniteDataType' = None, + binary: bool = False, query_id: Optional[int] = None) -> 'APIResult': """ Puts a value with a given key to cache only if the key does not already exist. @@ -472,7 +547,19 @@ def cache_put_if_absent( :return: API result data object. Contains zero status on success, non-zero status and an error description otherwise. """ + return __cache_put_if_absent(connection, cache, key, value, key_hint, value_hint, binary, query_id) + +async def cache_put_if_absent_async(connection: 'AioConnection', cache: Union[str, int], key: Any, value: Any, + key_hint: 'IgniteDataType' = None, value_hint: 'IgniteDataType' = None, + binary: bool = False, query_id: Optional[int] = None) -> 'APIResult': + """ + Async version of cache_put_if_absent. + """ + return await __cache_put_if_absent(connection, cache, key, value, key_hint, value_hint, binary, query_id) + + +def __cache_put_if_absent(connection, cache, key, value, key_hint, value_hint, binary, query_id): query_struct = Query( OP_CACHE_PUT_IF_ABSENT, [ @@ -483,8 +570,8 @@ def cache_put_if_absent( ], query_id=query_id, ) - result = query_struct.perform( - connection, + return query_perform( + query_struct, connection, query_params={ 'hash_code': cache_id(cache), 'flag': 1 if binary else 0, @@ -494,17 +581,13 @@ def cache_put_if_absent( response_config=[ ('success', Bool), ], + post_process_fun=__post_process_value_by_key('success') ) - if result.status == 0: - result.value = result.value['success'] - return result -def cache_get_and_put_if_absent( - connection: 'Connection', cache: Union[str, int], key: Any, value: Any, - key_hint: 'IgniteDataType' = None, value_hint: 'IgniteDataType' = None, - binary: bool = False, query_id: Optional[int] = None, -) -> 'APIResult': +def cache_get_and_put_if_absent(connection: 'Connection', cache: Union[str, int], key: Any, value: Any, + key_hint: 'IgniteDataType' = None, value_hint: 'IgniteDataType' = None, + binary: bool = False, query_id: Optional[int] = None) -> 'APIResult': """ Puts a value with a given key to cache only if the key does not already exist. @@ -525,7 +608,19 @@ def cache_get_and_put_if_absent( :return: API result data object. Contains zero status and an old value or None on success, non-zero status and an error description otherwise. """ + return __cache_get_and_put_if_absent(connection, cache, key, value, key_hint, value_hint, binary, query_id) + + +async def cache_get_and_put_if_absent_async(connection: 'AioConnection', cache: Union[str, int], key: Any, value: Any, + key_hint: 'IgniteDataType' = None, value_hint: 'IgniteDataType' = None, + binary: bool = False, query_id: Optional[int] = None) -> 'APIResult': + """ + Async version of cache_get_and_put_if_absent. + """ + return await __cache_get_and_put_if_absent(connection, cache, key, value, key_hint, value_hint, binary, query_id) + +def __cache_get_and_put_if_absent(connection, cache, key, value, key_hint, value_hint, binary, query_id): query_struct = Query( OP_CACHE_GET_AND_PUT_IF_ABSENT, [ @@ -536,8 +631,8 @@ def cache_get_and_put_if_absent( ], query_id=query_id, ) - result = query_struct.perform( - connection, + return query_perform( + query_struct, connection, query_params={ 'hash_code': cache_id(cache), 'flag': 1 if binary else 0, @@ -547,17 +642,13 @@ def cache_get_and_put_if_absent( response_config=[ ('value', AnyDataObject), ], + post_process_fun=__post_process_value_by_key('value') ) - if result.status == 0: - result.value = result.value['value'] - return result -def cache_replace( - connection: 'Connection', cache: Union[str, int], key: Any, value: Any, - key_hint: 'IgniteDataType' = None, value_hint: 'IgniteDataType' = None, - binary: bool = False, query_id: Optional[int] = None, -) -> 'APIResult': +def cache_replace(connection: 'Connection', cache: Union[str, int], key: Any, value: Any, + key_hint: 'IgniteDataType' = None, value_hint: 'IgniteDataType' = None, binary: bool = False, + query_id: Optional[int] = None) -> 'APIResult': """ Puts a value with a given key to cache only if the key already exist. @@ -578,7 +669,19 @@ def cache_replace( success code, or non-zero status and an error description if something has gone wrong. """ + return __cache_replace(connection, cache, key, value, key_hint, value_hint, binary, query_id) + + +async def cache_replace_async(connection: 'AioConnection', cache: Union[str, int], key: Any, value: Any, + key_hint: 'IgniteDataType' = None, value_hint: 'IgniteDataType' = None, + binary: bool = False, query_id: Optional[int] = None) -> 'APIResult': + """ + Async version of cache_replace. + """ + return await __cache_replace(connection, cache, key, value, key_hint, value_hint, binary, query_id) + +def __cache_replace(connection, cache, key, value, key_hint, value_hint, binary, query_id): query_struct = Query( OP_CACHE_REPLACE, [ @@ -589,8 +692,8 @@ def cache_replace( ], query_id=query_id, ) - result = query_struct.perform( - connection, + return query_perform( + query_struct, connection, query_params={ 'hash_code': cache_id(cache), 'flag': 1 if binary else 0, @@ -600,18 +703,14 @@ def cache_replace( response_config=[ ('success', Bool), ], + post_process_fun=__post_process_value_by_key('success') ) - if result.status == 0: - result.value = result.value['success'] - return result -def cache_replace_if_equals( - connection: 'Connection', cache: Union[str, int], - key: Any, sample: Any, value: Any, key_hint: 'IgniteDatatType' = None, - sample_hint: 'IgniteDataType' = None, value_hint: 'IgniteDataType' = None, - binary: bool = False, query_id: Optional[int] = None, -) -> 'APIResult': +def cache_replace_if_equals(connection: 'Connection', cache: Union[str, int], key: Any, sample: Any, value: Any, + key_hint: 'IgniteDataType' = None, sample_hint: 'IgniteDataType' = None, + value_hint: 'IgniteDataType' = None, binary: bool = False, + query_id: Optional[int] = None) -> 'APIResult': """ Puts a value with a given key to cache only if the key already exists and value equals provided sample. @@ -636,7 +735,23 @@ def cache_replace_if_equals( success code, or non-zero status and an error description if something has gone wrong. """ + return __cache_replace_if_equals(connection, cache, key, sample, value, key_hint, sample_hint, value_hint, binary, + query_id) + +async def cache_replace_if_equals_async( + connection: 'AioConnection', cache: Union[str, int], key: Any, sample: Any, value: Any, + key_hint: 'IgniteDataType' = None, sample_hint: 'IgniteDataType' = None, value_hint: 'IgniteDataType' = None, + binary: bool = False, query_id: Optional[int] = None) -> 'APIResult': + """ + Async version of cache_replace_if_equals. + """ + return await __cache_replace_if_equals(connection, cache, key, sample, value, key_hint, sample_hint, value_hint, + binary, query_id) + + +def __cache_replace_if_equals(connection, cache, key, sample, value, key_hint, sample_hint, value_hint, binary, + query_id): query_struct = Query( OP_CACHE_REPLACE_IF_EQUALS, [ @@ -648,8 +763,8 @@ def cache_replace_if_equals( ], query_id=query_id, ) - result = query_struct.perform( - connection, + return query_perform( + query_struct, connection, query_params={ 'hash_code': cache_id(cache), 'flag': 1 if binary else 0, @@ -660,16 +775,12 @@ def cache_replace_if_equals( response_config=[ ('success', Bool), ], + post_process_fun=__post_process_value_by_key('success') ) - if result.status == 0: - result.value = result.value['success'] - return result -def cache_clear( - connection: 'Connection', cache: Union[str, int], - binary: bool = False, query_id: Optional[int] = None, -) -> 'APIResult': +def cache_clear(connection: 'Connection', cache: Union[str, int], binary: bool = False, + query_id: Optional[int] = None) -> 'APIResult': """ Clears the cache without notifying listeners or cache writers. @@ -683,7 +794,18 @@ def cache_clear( :return: API result data object. Contains zero status on success, non-zero status and an error description otherwise. """ + return __cache_clear(connection, cache, binary, query_id) + + +async def cache_clear_async(connection: 'AioConnection', cache: Union[str, int], binary: bool = False, + query_id: Optional[int] = None) -> 'APIResult': + """ + Async version of cache_clear. + """ + return await __cache_clear(connection, cache, binary, query_id) + +def __cache_clear(connection, cache, binary, query_id): query_struct = Query( OP_CACHE_CLEAR, [ @@ -692,8 +814,8 @@ def cache_clear( ], query_id=query_id, ) - return query_struct.perform( - connection, + return query_perform( + query_struct, connection, query_params={ 'hash_code': cache_id(cache), 'flag': 1 if binary else 0, @@ -701,11 +823,8 @@ def cache_clear( ) -def cache_clear_key( - connection: 'Connection', cache: Union[str, int], key: Any, - key_hint: 'IgniteDataType' = None, - binary: bool = False, query_id: Optional[int] = None, -) -> 'APIResult': +def cache_clear_key(connection: 'Connection', cache: Union[str, int], key: Any, key_hint: 'IgniteDataType' = None, + binary: bool = False, query_id: Optional[int] = None) -> 'APIResult': """ Clears the cache key without notifying listeners or cache writers. @@ -722,7 +841,19 @@ def cache_clear_key( :return: API result data object. Contains zero status on success, non-zero status and an error description otherwise. """ + return __cache_clear_key(connection, cache, key, key_hint, binary, query_id) + + +async def cache_clear_key_async(connection: 'AioConnection', cache: Union[str, int], key: Any, + key_hint: 'IgniteDataType' = None, binary: bool = False, + query_id: Optional[int] = None) -> 'APIResult': + """ + Async version of cache_clear_key. + """ + return await __cache_clear_key(connection, cache, key, key_hint, binary, query_id) + +def __cache_clear_key(connection, cache, key, key_hint, binary, query_id): query_struct = Query( OP_CACHE_CLEAR_KEY, [ @@ -732,8 +863,8 @@ def cache_clear_key( ], query_id=query_id, ) - return query_struct.perform( - connection, + return query_perform( + query_struct, connection, query_params={ 'hash_code': cache_id(cache), 'flag': 1 if binary else 0, @@ -742,10 +873,8 @@ def cache_clear_key( ) -def cache_clear_keys( - connection: 'Connection', cache: Union[str, int], keys: list, - binary: bool = False, query_id: Optional[int] = None, -) -> 'APIResult': +def cache_clear_keys(connection: 'Connection', cache: Union[str, int], keys: Iterable, binary: bool = False, + query_id: Optional[int] = None) -> 'APIResult': """ Clears the cache keys without notifying listeners or cache writers. @@ -760,7 +889,18 @@ def cache_clear_keys( :return: API result data object. Contains zero status on success, non-zero status and an error description otherwise. """ + return __cache_clear_keys(connection, cache, keys, binary, query_id) + +async def cache_clear_keys_async(connection: 'AioConnection', cache: Union[str, int], keys: Iterable, + binary: bool = False, query_id: Optional[int] = None) -> 'APIResult': + """ + Async version of cache_clear_keys. + """ + return await __cache_clear_keys(connection, cache, keys, binary, query_id) + + +def __cache_clear_keys(connection, cache, keys, binary, query_id): query_struct = Query( OP_CACHE_CLEAR_KEYS, [ @@ -770,8 +910,8 @@ def cache_clear_keys( ], query_id=query_id, ) - return query_struct.perform( - connection, + return query_perform( + query_struct, connection, query_params={ 'hash_code': cache_id(cache), 'flag': 1 if binary else 0, @@ -780,11 +920,8 @@ def cache_clear_keys( ) -def cache_remove_key( - connection: 'Connection', cache: Union[str, int], key: Any, - key_hint: 'IgniteDataType' = None, - binary: bool = False, query_id: Optional[int] = None, -) -> 'APIResult': +def cache_remove_key(connection: 'Connection', cache: Union[str, int], key: Any, key_hint: 'IgniteDataType' = None, + binary: bool = False, query_id: Optional[int] = None) -> 'APIResult': """ Clears the cache key without notifying listeners or cache writers. @@ -802,7 +939,19 @@ def cache_remove_key( success code, or non-zero status and an error description if something has gone wrong. """ + return __cache_remove_key(connection, cache, key, key_hint, binary, query_id) + + +async def cache_remove_key_async(connection: 'AioConnection', cache: Union[str, int], key: Any, + key_hint: 'IgniteDataType' = None, binary: bool = False, + query_id: Optional[int] = None) -> 'APIResult': + """ + Async version of cache_remove_key. + """ + return await __cache_remove_key(connection, cache, key, key_hint, binary, query_id) + +def __cache_remove_key(connection, cache, key, key_hint, binary, query_id): query_struct = Query( OP_CACHE_REMOVE_KEY, [ @@ -812,8 +961,8 @@ def cache_remove_key( ], query_id=query_id, ) - result = query_struct.perform( - connection, + return query_perform( + query_struct, connection, query_params={ 'hash_code': cache_id(cache), 'flag': 1 if binary else 0, @@ -822,17 +971,13 @@ def cache_remove_key( response_config=[ ('success', Bool), ], + post_process_fun=__post_process_value_by_key('success') ) - if result.status == 0: - result.value = result.value['success'] - return result -def cache_remove_if_equals( - connection: 'Connection', cache: Union[str, int], key: Any, sample: Any, - key_hint: 'IgniteDataType' = None, sample_hint: 'IgniteDataType' = None, - binary: bool = False, query_id: Optional[int] = None, -) -> 'APIResult': +def cache_remove_if_equals(connection: 'Connection', cache: Union[str, int], key: Any, sample: Any, + key_hint: 'IgniteDataType' = None, sample_hint: 'IgniteDataType' = None, + binary: bool = False, query_id: Optional[int] = None) -> 'APIResult': """ Removes an entry with a given key if provided value is equal to actual value, notifying listeners and cache writers. @@ -854,7 +999,19 @@ def cache_remove_if_equals( success code, or non-zero status and an error description if something has gone wrong. """ + return __cache_remove_if_equals(connection, cache, key, sample, key_hint, sample_hint, binary, query_id) + + +async def cache_remove_if_equals_async( + connection: 'AioConnection', cache: Union[str, int], key: Any, sample: Any, key_hint: 'IgniteDataType' = None, + sample_hint: 'IgniteDataType' = None, binary: bool = False, query_id: Optional[int] = None) -> 'APIResult': + """ + Async version of cache_remove_if_equals. + """ + return await __cache_remove_if_equals(connection, cache, key, sample, key_hint, sample_hint, binary, query_id) + +def __cache_remove_if_equals(connection, cache, key, sample, key_hint, sample_hint, binary, query_id): query_struct = Query( OP_CACHE_REMOVE_IF_EQUALS, [ @@ -865,8 +1022,8 @@ def cache_remove_if_equals( ], query_id=query_id, ) - result = query_struct.perform( - connection, + return query_perform( + query_struct, connection, query_params={ 'hash_code': cache_id(cache), 'flag': 1 if binary else 0, @@ -876,16 +1033,12 @@ def cache_remove_if_equals( response_config=[ ('success', Bool), ], + post_process_fun=__post_process_value_by_key('success') ) - if result.status == 0: - result.value = result.value['success'] - return result -def cache_remove_keys( - connection: 'Connection', cache: Union[str, int], keys: Iterable, - binary: bool = False, query_id: Optional[int] = None, -) -> 'APIResult': +def cache_remove_keys(connection: 'Connection', cache: Union[str, int], keys: Iterable, binary: bool = False, + query_id: Optional[int] = None) -> 'APIResult': """ Removes entries with given keys, notifying listeners and cache writers. @@ -900,7 +1053,18 @@ def cache_remove_keys( :return: API result data object. Contains zero status on success, non-zero status and an error description otherwise. """ + return __cache_remove_keys(connection, cache, keys, binary, query_id) + +async def cache_remove_keys_async(connection: 'AioConnection', cache: Union[str, int], keys: Iterable, + binary: bool = False, query_id: Optional[int] = None) -> 'APIResult': + """ + Async version of cache_remove_keys. + """ + return await __cache_remove_keys(connection, cache, keys, binary, query_id) + + +def __cache_remove_keys(connection, cache, keys, binary, query_id): query_struct = Query( OP_CACHE_REMOVE_KEYS, [ @@ -910,8 +1074,8 @@ def cache_remove_keys( ], query_id=query_id, ) - return query_struct.perform( - connection, + return query_perform( + query_struct, connection, query_params={ 'hash_code': cache_id(cache), 'flag': 1 if binary else 0, @@ -920,10 +1084,8 @@ def cache_remove_keys( ) -def cache_remove_all( - connection: 'Connection', cache: Union[str, int], - binary: bool = False, query_id: Optional[int] = None, -) -> 'APIResult': +def cache_remove_all(connection: 'Connection', cache: Union[str, int], binary: bool = False, + query_id: Optional[int] = None) -> 'APIResult': """ Removes all entries from cache, notifying listeners and cache writers. @@ -937,7 +1099,18 @@ def cache_remove_all( :return: API result data object. Contains zero status on success, non-zero status and an error description otherwise. """ + return __cache_remove_all(connection, cache, binary, query_id) + + +async def cache_remove_all_async(connection: 'AioConnection', cache: Union[str, int], binary: bool = False, + query_id: Optional[int] = None) -> 'APIResult': + """ + Async version of cache_remove_all. + """ + return await __cache_remove_all(connection, cache, binary, query_id) + +def __cache_remove_all(connection, cache, binary, query_id): query_struct = Query( OP_CACHE_REMOVE_ALL, [ @@ -946,8 +1119,8 @@ def cache_remove_all( ], query_id=query_id, ) - return query_struct.perform( - connection, + return query_perform( + query_struct, connection, query_params={ 'hash_code': cache_id(cache), 'flag': 1 if binary else 0, @@ -955,10 +1128,8 @@ def cache_remove_all( ) -def cache_get_size( - connection: 'Connection', cache: Union[str, int], peek_modes: int = 0, - binary: bool = False, query_id: Optional[int] = None, -) -> 'APIResult': +def cache_get_size(connection: 'Connection', cache: Union[str, int], peek_modes: Union[int, list, tuple] = 0, + binary: bool = False, query_id: Optional[int] = None) -> 'APIResult': """ Gets the number of entries in cache. @@ -976,6 +1147,16 @@ def cache_get_size( cache entries on success, non-zero status and an error description otherwise. """ + return __cache_get_size(connection, cache, peek_modes, binary, query_id) + + +async def cache_get_size_async(connection: 'AioConnection', cache: Union[str, int], + peek_modes: Union[int, list, tuple] = 0, binary: bool = False, + query_id: Optional[int] = None) -> 'APIResult': + return await __cache_get_size(connection, cache, peek_modes, binary, query_id) + + +def __cache_get_size(connection, cache, peek_modes, binary, query_id): if not isinstance(peek_modes, (list, tuple)): peek_modes = [peek_modes] if peek_modes else [] @@ -988,8 +1169,8 @@ def cache_get_size( ], query_id=query_id, ) - result = query_struct.perform( - connection, + return query_perform( + query_struct, connection, query_params={ 'hash_code': cache_id(cache), 'flag': 1 if binary else 0, @@ -998,21 +1179,17 @@ def cache_get_size( response_config=[ ('count', Long), ], + post_process_fun=__post_process_value_by_key('count') ) - if result.status == 0: - result.value = result.value['count'] - return result -def cache_local_peek( - conn: 'Connection', cache: Union[str, int], - key: Any, key_hint: 'IgniteDataType' = None, peek_modes: int = 0, - binary: bool = False, query_id: Optional[int] = None, -) -> 'APIResult': +def cache_local_peek(conn: 'Connection', cache: Union[str, int], key: Any, key_hint: 'IgniteDataType' = None, + peek_modes: Union[int, list, tuple] = 0, binary: bool = False, + query_id: Optional[int] = None) -> 'APIResult': """ Peeks at in-memory cached value using default optional peek mode. - This method will not load value from any persistent store or from a remote + This method will not load value from any cache store or from a remote node. :param conn: connection: connection to Ignite server, @@ -1031,6 +1208,19 @@ def cache_local_peek( :return: API result data object. Contains zero status and a peeked value (null if not found). """ + return __cache_local_peek(conn, cache, key, key_hint, peek_modes, binary, query_id) + + +async def cache_local_peek_async( + conn: 'AioConnection', cache: Union[str, int], key: Any, key_hint: 'IgniteDataType' = None, + peek_modes: Union[int, list, tuple] = 0, binary: bool = False, query_id: Optional[int] = None) -> 'APIResult': + """ + Async version of cache_local_peek. + """ + return await __cache_local_peek(conn, cache, key, key_hint, peek_modes, binary, query_id) + + +def __cache_local_peek(conn, cache, key, key_hint, peek_modes, binary, query_id): if not isinstance(peek_modes, (list, tuple)): peek_modes = [peek_modes] if peek_modes else [] @@ -1044,8 +1234,8 @@ def cache_local_peek( ], query_id=query_id, ) - result = query_struct.perform( - conn, + return query_perform( + query_struct, conn, query_params={ 'hash_code': cache_id(cache), 'flag': 1 if binary else 0, @@ -1055,8 +1245,14 @@ def cache_local_peek( response_config=[ ('value', AnyDataObject), ], + post_process_fun=__post_process_value_by_key('value') ) - if result.status != 0: + + +def __post_process_value_by_key(key): + def internal(result): + if result.status == 0: + result.value = result.value[key] + return result - result.value = result.value['value'] - return result + return internal diff --git a/pyignite/api/result.py b/pyignite/api/result.py index f60a437..f134be9 100644 --- a/pyignite/api/result.py +++ b/pyignite/api/result.py @@ -32,7 +32,7 @@ class APIResult: message = 'Success' value = None - def __init__(self, response: 'Response'): + def __init__(self, response): self.status = getattr(response, 'status_code', OP_SUCCESS) self.query_id = response.query_id if hasattr(response, 'error_message'): diff --git a/pyignite/api/sql.py b/pyignite/api/sql.py index dc470d1..b10cc7d 100644 --- a/pyignite/api/sql.py +++ b/pyignite/api/sql.py @@ -15,23 +15,21 @@ from typing import Union -from pyignite.constants import * -from pyignite.datatypes import ( - AnyDataArray, AnyDataObject, Bool, Byte, Int, Long, Map, Null, String, - StructArray, -) +from pyignite.connection import AioConnection, Connection +from pyignite.datatypes import AnyDataArray, AnyDataObject, Bool, Byte, Int, Long, Map, Null, String, StructArray from pyignite.datatypes.sql import StatementType -from pyignite.queries import Query -from pyignite.queries.op_codes import * +from pyignite.queries import Query, query_perform +from pyignite.queries.op_codes import ( + OP_QUERY_SCAN, OP_QUERY_SCAN_CURSOR_GET_PAGE, OP_QUERY_SQL, OP_QUERY_SQL_CURSOR_GET_PAGE, OP_QUERY_SQL_FIELDS, + OP_QUERY_SQL_FIELDS_CURSOR_GET_PAGE, OP_RESOURCE_CLOSE +) from pyignite.utils import cache_id, deprecated from .result import APIResult +from ..queries.response import SQLResponse -def scan( - conn: 'Connection', cache: Union[str, int], page_size: int, - partitions: int = -1, local: bool = False, binary: bool = False, - query_id: int = None, -) -> APIResult: +def scan(conn: 'Connection', cache: Union[str, int], page_size: int, partitions: int = -1, local: bool = False, + binary: bool = False, query_id: int = None) -> APIResult: """ Performs scan query. @@ -58,7 +56,24 @@ def scan( * `more`: bool, True if more data is available for subsequent ‘scan_cursor_get_page’ calls. """ + return __scan(conn, cache, page_size, partitions, local, binary, query_id) + + +async def scan_async(conn: 'AioConnection', cache: Union[str, int], page_size: int, partitions: int = -1, + local: bool = False, binary: bool = False, query_id: int = None) -> APIResult: + """ + Async version of scan. + """ + return await __scan(conn, cache, page_size, partitions, local, binary, query_id) + + +def __query_result_post_process(result): + if result.status == 0: + result.value = dict(result.value) + return result + +def __scan(conn, cache, page_size, partitions, local, binary, query_id): query_struct = Query( OP_QUERY_SCAN, [ @@ -71,8 +86,8 @@ def scan( ], query_id=query_id, ) - result = query_struct.perform( - conn, + return query_perform( + query_struct, conn, query_params={ 'hash_code': cache_id(cache), 'flag': 1 if binary else 0, @@ -86,15 +101,11 @@ def scan( ('data', Map), ('more', Bool), ], + post_process_fun=__query_result_post_process ) - if result.status == 0: - result.value = dict(result.value) - return result -def scan_cursor_get_page( - conn: 'Connection', cursor: int, query_id: int = None, -) -> APIResult: +def scan_cursor_get_page(conn: 'Connection', cursor: int, query_id: int = None) -> APIResult: """ Fetches the next scan query cursor page by cursor ID that is obtained from `scan` function. @@ -114,7 +125,14 @@ def scan_cursor_get_page( * `more`: bool, True if more data is available for subsequent ‘scan_cursor_get_page’ calls. """ + return __scan_cursor_get_page(conn, cursor, query_id) + +async def scan_cursor_get_page_async(conn: 'AioConnection', cursor: int, query_id: int = None) -> APIResult: + return await __scan_cursor_get_page(conn, cursor, query_id) + + +def __scan_cursor_get_page(conn, cursor, query_id): query_struct = Query( OP_QUERY_SCAN_CURSOR_GET_PAGE, [ @@ -122,8 +140,8 @@ def scan_cursor_get_page( ], query_id=query_id, ) - result = query_struct.perform( - conn, + return query_perform( + query_struct, conn, query_params={ 'cursor': cursor, }, @@ -131,10 +149,8 @@ def scan_cursor_get_page( ('data', Map), ('more', Bool), ], + post_process_fun=__query_result_post_process ) - if result.status == 0: - result.value = dict(result.value) - return result @deprecated(version='1.2.0', reason="This API is deprecated and will be removed in the following major release. " @@ -322,6 +338,31 @@ def sql_fields( * `more`: bool, True if more data is available for subsequent ‘sql_fields_cursor_get_page’ calls. """ + return __sql_fields(conn, cache, query_str, page_size, query_args, schema, statement_type, distributed_joins, + local, replicated_only, enforce_join_order, collocated, lazy, include_field_names, max_rows, + timeout, binary, query_id) + + +async def sql_fields_async( + conn: 'AioConnection', cache: Union[str, int], + query_str: str, page_size: int, query_args=None, schema: str = None, + statement_type: int = StatementType.ANY, distributed_joins: bool = False, + local: bool = False, replicated_only: bool = False, + enforce_join_order: bool = False, collocated: bool = False, + lazy: bool = False, include_field_names: bool = False, max_rows: int = -1, + timeout: int = 0, binary: bool = False, query_id: int = None +) -> APIResult: + """ + Async version of sql_fields. + """ + return await __sql_fields(conn, cache, query_str, page_size, query_args, schema, statement_type, distributed_joins, + local, replicated_only, enforce_join_order, collocated, lazy, include_field_names, + max_rows, timeout, binary, query_id) + + +def __sql_fields(conn, cache, query_str, page_size, query_args, schema, statement_type, distributed_joins, local, + replicated_only, enforce_join_order, collocated, lazy, include_field_names, max_rows, timeout, + binary, query_id): if query_args is None: query_args = [] @@ -346,10 +387,11 @@ def sql_fields( ('include_field_names', Bool), ], query_id=query_id, + response_type=SQLResponse ) - return query_struct.perform( - conn, + return query_perform( + query_struct, conn, query_params={ 'hash_code': cache_id(cache), 'flag': 1 if binary else 0, @@ -368,15 +410,12 @@ def sql_fields( 'timeout': timeout, 'include_field_names': include_field_names, }, - sql=True, include_field_names=include_field_names, has_cursor=True, ) -def sql_fields_cursor_get_page( - conn: 'Connection', cursor: int, field_count: int, query_id: int = None, -) -> APIResult: +def sql_fields_cursor_get_page(conn: 'Connection', cursor: int, field_count: int, query_id: int = None) -> APIResult: """ Retrieves the next query result page by cursor ID from `sql_fields`. @@ -396,7 +435,18 @@ def sql_fields_cursor_get_page( * `more`: bool, True if more data is available for subsequent ‘sql_fields_cursor_get_page’ calls. """ + return __sql_fields_cursor_get_page(conn, cursor, field_count, query_id) + + +async def sql_fields_cursor_get_page_async(conn: 'AioConnection', cursor: int, field_count: int, + query_id: int = None) -> APIResult: + """ + Async version sql_fields_cursor_get_page. + """ + return await __sql_fields_cursor_get_page(conn, cursor, field_count, query_id) + +def __sql_fields_cursor_get_page(conn, cursor, field_count, query_id): query_struct = Query( OP_QUERY_SQL_FIELDS_CURSOR_GET_PAGE, [ @@ -404,16 +454,20 @@ def sql_fields_cursor_get_page( ], query_id=query_id, ) - result = query_struct.perform( - conn, + return query_perform( + query_struct, conn, query_params={ 'cursor': cursor, }, response_config=[ ('data', StructArray([(f'field_{i}', AnyDataObject) for i in range(field_count)])), ('more', Bool), - ] + ], + post_process_fun=__post_process_sql_fields_cursor ) + + +def __post_process_sql_fields_cursor(result): if result.status != 0: return result @@ -427,9 +481,7 @@ def sql_fields_cursor_get_page( return result -def resource_close( - conn: 'Connection', cursor: int, query_id: int = None -) -> APIResult: +def resource_close(conn: 'Connection', cursor: int, query_id: int = None) -> APIResult: """ Closes a resource, such as query cursor. @@ -441,7 +493,14 @@ def resource_close( :return: API result data object. Contains zero status on success, non-zero status and an error description otherwise. """ + return __resource_close(conn, cursor, query_id) + + +async def resource_close_async(conn: 'AioConnection', cursor: int, query_id: int = None) -> APIResult: + return await __resource_close(conn, cursor, query_id) + +def __resource_close(conn, cursor, query_id): query_struct = Query( OP_RESOURCE_CLOSE, [ @@ -449,9 +508,9 @@ def resource_close( ], query_id=query_id, ) - return query_struct.perform( - conn, + return query_perform( + query_struct, conn, query_params={ 'cursor': cursor, - }, + } ) diff --git a/pyignite/binary.py b/pyignite/binary.py index da62bb5..4e34267 100644 --- a/pyignite/binary.py +++ b/pyignite/binary.py @@ -27,15 +27,22 @@ from collections import OrderedDict import ctypes +from io import SEEK_CUR from typing import Any import attr -from pyignite.constants import * -from .datatypes import * +from .constants import PROTOCOL_BYTE_ORDER +from .datatypes import ( + Null, ByteObject, ShortObject, IntObject, LongObject, FloatObject, DoubleObject, CharObject, BoolObject, UUIDObject, + DateObject, TimestampObject, TimeObject, EnumObject, BinaryEnumObject, ByteArrayObject, ShortArrayObject, + IntArrayObject, LongArrayObject, FloatArrayObject, DoubleArrayObject, CharArrayObject, BoolArrayObject, + UUIDArrayObject, DateArrayObject, TimestampArrayObject, TimeArrayObject, EnumArrayObject, String, StringArrayObject, + DecimalObject, DecimalArrayObject, ObjectArrayObject, CollectionObject, MapObject, BinaryObject, WrappedDataObject +) from .datatypes.base import IgniteDataTypeProps from .exceptions import ParseError -from .utils import entity_id, hashcode, schema_id +from .utils import entity_id, schema_id ALLOWED_FIELD_TYPES = [ @@ -69,12 +76,14 @@ def schema_id(self) -> int: def __new__(cls, *args, **kwargs) -> Any: # allow all items in Binary Object schema to be populated as optional # arguments to `__init__()` with sensible defaults. - attributes = {} - for k, v in cls.schema.items(): - attributes[k] = attr.ib(type=getattr(v, 'pythonic', type(None)), default=getattr(v, 'default', None)) - - attributes.update({'version': attr.ib(type=int, default=1)}) - cls = attr.s(cls, these=attributes) + if not attr.has(cls): + attributes = { + k: attr.ib(type=getattr(v, 'pythonic', type(None)), default=getattr(v, 'default', None)) + for k, v in cls.schema.items() + } + + attributes.update({'version': attr.ib(type=int, default=1)}) + cls = attr.s(cls, these=attributes) # skip parameters return super().__new__(cls) @@ -99,7 +108,7 @@ def __new__( """ Sort out class creation arguments. """ result = super().__new__( - mcs, name, (GenericObjectProps, )+base_classes, namespace + mcs, name, (GenericObjectProps, ) + base_classes, namespace ) def _from_python(self, stream, save_to_buf=False): @@ -111,10 +120,37 @@ def _from_python(self, stream, save_to_buf=False): :param stream: BinaryStream :param save_to_buf: Optional. If True, save serialized data to buffer. """ + initial_pos = stream.tell() + header, header_class = write_header(self, stream) + + offsets = [ctypes.sizeof(header_class)] + schema_items = list(self.schema.items()) + for field_name, field_type in schema_items: + val = getattr(self, field_name, getattr(field_type, 'default', None)) + field_start_pos = stream.tell() + field_type.from_python(stream, val) + offsets.append(max(offsets) + stream.tell() - field_start_pos) + + write_footer(self, stream, header, header_class, schema_items, offsets, initial_pos, save_to_buf) - compact_footer = stream.compact_footer + async def _from_python_async(self, stream, save_to_buf=False): + """ + Async version of _from_python + """ + initial_pos = stream.tell() + header, header_class = write_header(self, stream) - # prepare header + offsets = [ctypes.sizeof(header_class)] + schema_items = list(self.schema.items()) + for field_name, field_type in schema_items: + val = getattr(self, field_name, getattr(field_type, 'default', None)) + field_start_pos = stream.tell() + await field_type.from_python_async(stream, val) + offsets.append(max(offsets) + stream.tell() - field_start_pos) + + write_footer(self, stream, header, header_class, schema_items, offsets, initial_pos, save_to_buf) + + def write_header(obj, stream): header_class = BinaryObject.build_header() header = header_class() header.type_code = int.from_bytes( @@ -122,36 +158,30 @@ def _from_python(self, stream, save_to_buf=False): byteorder=PROTOCOL_BYTE_ORDER ) header.flags = BinaryObject.USER_TYPE | BinaryObject.HAS_SCHEMA - if compact_footer: + if stream.compact_footer: header.flags |= BinaryObject.COMPACT_FOOTER - header.version = self.version - header.type_id = self.type_id - header.schema_id = self.schema_id + header.version = obj.version + header.type_id = obj.type_id + header.schema_id = obj.schema_id - header_len = ctypes.sizeof(header_class) - initial_pos = stream.tell() + stream.seek(ctypes.sizeof(header_class), SEEK_CUR) - # create fields and calculate offsets - offsets = [ctypes.sizeof(header_class)] - schema_items = list(self.schema.items()) - - stream.seek(initial_pos + header_len) - for field_name, field_type in schema_items: - val = getattr(self, field_name, getattr(field_type, 'default', None)) - field_start_pos = stream.tell() - field_type.from_python(stream, val) - offsets.append(max(offsets) + stream.tell() - field_start_pos) + return header, header_class + def write_footer(obj, stream, header, header_class, schema_items, offsets, initial_pos, save_to_buf): offsets = offsets[:-1] + header_len = ctypes.sizeof(header_class) # create footer if max(offsets, default=0) < 255: header.flags |= BinaryObject.OFFSET_ONE_BYTE elif max(offsets) < 65535: header.flags |= BinaryObject.OFFSET_TWO_BYTES + schema_class = BinaryObject.schema_type(header.flags) * len(offsets) schema = schema_class() - if compact_footer: + + if stream.compact_footer: for i, offset in enumerate(offsets): schema[i] = offset else: @@ -171,8 +201,8 @@ def _from_python(self, stream, save_to_buf=False): stream.write(schema) if save_to_buf: - self._buffer = bytes(stream.mem_view(initial_pos, stream.tell() - initial_pos)) - self._hashcode = header.hash_code + obj._buffer = bytes(stream.mem_view(initial_pos, stream.tell() - initial_pos)) + obj._hashcode = header.hash_code def _setattr(self, attr_name: str, attr_value: Any): # reset binary representation, if any field is changed @@ -184,6 +214,7 @@ def _setattr(self, attr_name: str, attr_value: Any): super(result, self).__setattr__(attr_name, attr_value) setattr(result, _from_python.__name__, _from_python) + setattr(result, _from_python_async.__name__, _from_python_async) setattr(result, '__setattr__', _setattr) setattr(result, '_buffer', None) setattr(result, '_hashcode', None) diff --git a/pyignite/cache.py b/pyignite/cache.py index a91a3cf..5fba6fb 100644 --- a/pyignite/cache.py +++ b/pyignite/cache.py @@ -16,54 +16,145 @@ import time from typing import Any, Dict, Iterable, Optional, Tuple, Union -from .constants import * -from .binary import GenericObjectMeta, unwrap_binary +from .constants import AFFINITY_RETRIES, AFFINITY_DELAY +from .binary import GenericObjectMeta from .datatypes import prop_codes from .datatypes.internal import AnyDataObject -from .exceptions import ( - CacheCreationError, CacheError, ParameterError, SQLError, - connection_errors, -) -from .utils import ( - cache_id, get_field_by_id, is_wrapped, - status_to_exception, unsigned -) +from .exceptions import CacheCreationError, CacheError, ParameterError, SQLError, connection_errors +from .utils import cache_id, get_field_by_id, status_to_exception, unsigned from .api.cache_config import ( - cache_create, cache_create_with_config, - cache_get_or_create, cache_get_or_create_with_config, - cache_destroy, cache_get_configuration, + cache_create, cache_create_with_config, cache_get_or_create, cache_get_or_create_with_config, cache_destroy, + cache_get_configuration ) from .api.key_value import ( - cache_get, cache_put, cache_get_all, cache_put_all, cache_replace, - cache_clear, cache_clear_key, cache_clear_keys, - cache_contains_key, cache_contains_keys, - cache_get_and_put, cache_get_and_put_if_absent, cache_put_if_absent, - cache_get_and_remove, cache_get_and_replace, - cache_remove_key, cache_remove_keys, cache_remove_all, - cache_remove_if_equals, cache_replace_if_equals, cache_get_size, + cache_get, cache_put, cache_get_all, cache_put_all, cache_replace, cache_clear, cache_clear_key, cache_clear_keys, + cache_contains_key, cache_contains_keys, cache_get_and_put, cache_get_and_put_if_absent, cache_put_if_absent, + cache_get_and_remove, cache_get_and_replace, cache_remove_key, cache_remove_keys, cache_remove_all, + cache_remove_if_equals, cache_replace_if_equals, cache_get_size ) -from .api.sql import scan, scan_cursor_get_page, sql, sql_cursor_get_page +from .cursors import ScanCursor, SqlCursor from .api.affinity import cache_get_node_partitions - PROP_CODES = set([ getattr(prop_codes, x) for x in dir(prop_codes) if x.startswith('PROP_') ]) -CACHE_CREATE_FUNCS = { - True: { - True: cache_get_or_create_with_config, - False: cache_create_with_config, - }, - False: { - True: cache_get_or_create, - False: cache_create, - }, -} - - -class Cache: + + +def get_cache(client: 'Client', settings: Union[str, dict]) -> 'Cache': + name, settings = __parse_settings(settings) + if settings: + raise ParameterError('Only cache name allowed as a parameter') + + return Cache(client, name) + + +def create_cache(client: 'Client', settings: Union[str, dict]) -> 'Cache': + name, settings = __parse_settings(settings) + + conn = client.random_node + if settings: + result = cache_create_with_config(conn, settings) + else: + result = cache_create(conn, name) + + if result.status != 0: + raise CacheCreationError(result.message) + + return Cache(client, name) + + +def get_or_create_cache(client: 'Client', settings: Union[str, dict]) -> 'Cache': + name, settings = __parse_settings(settings) + + conn = client.random_node + if settings: + result = cache_get_or_create_with_config(conn, settings) + else: + result = cache_get_or_create(conn, name) + + if result.status != 0: + raise CacheCreationError(result.message) + + return Cache(client, name) + + +def __parse_settings(settings: Union[str, dict]) -> Tuple[Optional[str], Optional[dict]]: + if isinstance(settings, str): + return settings, None + elif isinstance(settings, dict) and prop_codes.PROP_NAME in settings: + name = settings[prop_codes.PROP_NAME] + if len(settings) == 1: + return name, None + + if not set(settings).issubset(PROP_CODES): + raise ParameterError('One or more settings was not recognized') + + return name, settings + else: + raise ParameterError('You should supply at least cache name') + + +class BaseCacheMixin: + def _get_affinity_key(self, key, key_hint=None): + if key_hint is None: + key_hint = AnyDataObject.map_python_type(key) + + if self.affinity.get('is_applicable'): + config = self.affinity.get('cache_config') + if config: + affinity_key_id = config.get(key_hint.type_id) + + if affinity_key_id and isinstance(key, GenericObjectMeta): + return get_field_by_id(key, affinity_key_id) + + return key, key_hint + + def _update_affinity(self, full_affinity): + self.affinity['version'] = full_affinity['version'] + + full_mapping = full_affinity.get('partition_mapping') + if full_mapping and self.cache_id in full_mapping: + self.affinity.update(full_mapping[self.cache_id]) + + def _get_node_by_hashcode(self, hashcode, parts): + """ + Get node by key hashcode. Calculate partition and return node on that it is primary. + (algorithm is taken from `RendezvousAffinityFunction.java`) + """ + + # calculate partition for key or affinity key + # (algorithm is taken from `RendezvousAffinityFunction.java`) + mask = parts - 1 + + if parts & mask == 0: + part = (hashcode ^ (unsigned(hashcode) >> 16)) & mask + else: + part = abs(hashcode // parts) + + assert 0 <= part < parts, 'Partition calculation has failed' + + node_mapping = self.affinity.get('node_mapping') + if not node_mapping: + return None + + node_uuid, best_conn = None, None + for u, p in node_mapping.items(): + if part in p: + node_uuid = u + break + + if node_uuid: + for n in self.client._nodes: + if n.uuid == node_uuid: + best_conn = n + break + if best_conn and best_conn.alive: + return best_conn + + +class Cache(BaseCacheMixin): """ Ignite cache abstraction. Users should never use this class directly, but construct its instances with @@ -73,77 +164,18 @@ class Cache: :ref:`this example ` on how to do it. """ - affinity = None - _cache_id = None - _name = None - _client = None - _settings = None - - @staticmethod - def _validate_settings( - settings: Union[str, dict] = None, get_only: bool = False, - ): - if any([ - not settings, - type(settings) not in (str, dict), - type(settings) is dict and prop_codes.PROP_NAME not in settings, - ]): - raise ParameterError('You should supply at least cache name') - - if all([ - type(settings) is dict, - not set(settings).issubset(PROP_CODES), - ]): - raise ParameterError('One or more settings was not recognized') - - if get_only and type(settings) is dict and len(settings) != 1: - raise ParameterError('Only cache name allowed as a parameter') - - def __init__( - self, client: 'Client', settings: Union[str, dict] = None, - with_get: bool = False, get_only: bool = False, - ): + def __init__(self, client: 'Client', name: str): """ - Initialize cache object. + Initialize cache object. For internal use. :param client: Ignite client, - :param settings: cache settings. Can be a string (cache name) or a dict - of cache properties and their values. In this case PROP_NAME is - mandatory, - :param with_get: (optional) do not raise exception, if the cache - is already exists. Defaults to False, - :param get_only: (optional) do not communicate with Ignite server - at all, only create Cache instance. Defaults to False. + :param name: Cache name. """ self._client = client - self._validate_settings(settings) - if type(settings) == str: - self._name = settings - else: - self._name = settings[prop_codes.PROP_NAME] - - if not get_only: - func = CACHE_CREATE_FUNCS[type(settings) is dict][with_get] - result = func(client.random_node, settings) - if result.status != 0: - raise CacheCreationError(result.message) - + self._name = name + self._settings = None self._cache_id = cache_id(self._name) - self.affinity = { - 'version': (0, 0), - } - - def get_protocol_version(self) -> Optional[Tuple]: - """ - Returns the tuple of major, minor, and revision numbers of the used - thin protocol version, or None, if no connection to the Ignite cluster - was not yet established. - - This method is not a part of the public API. Unless you wish to - extend the `pyignite` capabilities (with additional testing, logging, - examining connections, et c.) you probably should not use it. - """ - return self.client.protocol_version + self.affinity = {'version': (0, 0)} @property def settings(self) -> Optional[dict]: @@ -197,18 +229,6 @@ def cache_id(self) -> int: """ return self._cache_id - def _process_binary(self, value: Any) -> Any: - """ - Detects and recursively unwraps Binary Object. - - :param value: anything that could be a Binary Object, - :return: the result of the Binary Object unwrapping with all other data - left intact. - """ - if is_wrapped(value): - return unwrap_binary(self._client, value) - return value - @status_to_exception(CacheError) def destroy(self): """ @@ -234,9 +254,7 @@ def _get_affinity(self, conn: 'Connection') -> Dict: return result - def get_best_node( - self, key: Any = None, key_hint: 'IgniteDataType' = None, - ) -> 'Connection': + def get_best_node(self, key: Any = None, key_hint: 'IgniteDataType' = None) -> 'Connection': """ Returns the node from the list of the nodes, opened by client, that most probably contains the needed key-value pair. See IEP-23. @@ -253,14 +271,11 @@ def get_best_node( conn = self._client.random_node if self.client.partition_aware and key is not None: - if key_hint is None: - key_hint = AnyDataObject.map_python_type(key) - if self.affinity['version'] < self._client.affinity_version: # update partition mapping while True: try: - self.affinity = self._get_affinity(conn) + full_affinity = self._get_affinity(conn) break except connection_errors: # retry if connection failed @@ -270,68 +285,23 @@ def get_best_node( # server did not create mapping in time return conn - # flatten it a bit - try: - self.affinity.update(self.affinity['partition_mapping'][0]) - except IndexError: - return conn - del self.affinity['partition_mapping'] - - # calculate the number of partitions - parts = 0 - if 'node_mapping' in self.affinity: - for p in self.affinity['node_mapping'].values(): - parts += len(p) - - self.affinity['number_of_partitions'] = parts + self._update_affinity(full_affinity) for conn in self.client._nodes: if not conn.alive: conn.reconnect() - else: - # get number of partitions - parts = self.affinity.get('number_of_partitions') + + parts = self.affinity.get('number_of_partitions') if not parts: return conn - if self.affinity['is_applicable']: - affinity_key_id = self.affinity['cache_config'].get( - key_hint.type_id, - None - ) - if affinity_key_id and isinstance(key, GenericObjectMeta): - key, key_hint = get_field_by_id(key, affinity_key_id) + key, key_hint = self._get_affinity_key(key, key_hint) + hashcode = key_hint.hashcode(key, self._client) - # calculate partition for key or affinity key - # (algorithm is taken from `RendezvousAffinityFunction.java`) - base_value = key_hint.hashcode(key, self._client) - mask = parts - 1 - - if parts & mask == 0: - part = (base_value ^ (unsigned(base_value) >> 16)) & mask - else: - part = abs(base_value // parts) - - assert 0 <= part < parts, 'Partition calculation has failed' - - # search for connection - try: - node_uuid, best_conn = None, None - for u, p in self.affinity['node_mapping'].items(): - if part in p: - node_uuid = u - break - - if node_uuid: - for n in conn.client._nodes: - if n.uuid == node_uuid: - best_conn = n - break - if best_conn and best_conn.alive: - conn = best_conn - except KeyError: - pass + best_node = self._get_node_by_hashcode(hashcode, parts) + if best_node: + return best_node return conn @@ -354,12 +324,12 @@ def get(self, key, key_hint: object = None) -> Any: key, key_hint=key_hint ) - result.value = self._process_binary(result.value) + result.value = self.client.unwrap_binary(result.value) return result @status_to_exception(CacheError) def put( - self, key, value, key_hint: object = None, value_hint: object = None + self, key, value, key_hint: object = None, value_hint: object = None ): """ Puts a value with a given key to cache (overwriting existing value @@ -392,7 +362,7 @@ def get_all(self, keys: list) -> list: result = cache_get_all(self.get_best_node(), self._cache_id, keys) if result.value: for key, value in result.value.items(): - result.value[key] = self._process_binary(value) + result.value[key] = self.client.unwrap_binary(value) return result @status_to_exception(CacheError) @@ -409,7 +379,7 @@ def put_all(self, pairs: dict): @status_to_exception(CacheError) def replace( - self, key, value, key_hint: object = None, value_hint: object = None + self, key, value, key_hint: object = None, value_hint: object = None ): """ Puts a value with a given key to cache only if the key already exist. @@ -429,7 +399,7 @@ def replace( self._cache_id, key, value, key_hint=key_hint, value_hint=value_hint ) - result.value = self._process_binary(result.value) + result.value = self.client.unwrap_binary(result.value) return result @status_to_exception(CacheError) @@ -465,6 +435,16 @@ def clear_key(self, key, key_hint: object = None): key_hint=key_hint ) + @status_to_exception(CacheError) + def clear_keys(self, keys: Iterable): + """ + Clears the cache key without notifying listeners or cache writers. + + :param keys: a list of keys or (key, type hint) tuples + """ + + return cache_clear_keys(self.get_best_node(), self._cache_id, keys) + @status_to_exception(CacheError) def contains_key(self, key, key_hint=None) -> bool: """ @@ -493,7 +473,7 @@ def contains_keys(self, keys: Iterable) -> bool: :param keys: a list of keys or (key, type hint) tuples, :return: boolean `True` when all keys are present, `False` otherwise. """ - return cache_contains_keys(self._client, self._cache_id, keys) + return cache_contains_keys(self.get_best_node(), self._cache_id, keys) @status_to_exception(CacheError) def get_and_put(self, key, value, key_hint=None, value_hint=None) -> Any: @@ -518,12 +498,12 @@ def get_and_put(self, key, value, key_hint=None, value_hint=None) -> Any: key, value, key_hint, value_hint ) - result.value = self._process_binary(result.value) + result.value = self.client.unwrap_binary(result.value) return result @status_to_exception(CacheError) def get_and_put_if_absent( - self, key, value, key_hint=None, value_hint=None + self, key, value, key_hint=None, value_hint=None ): """ Puts a value with a given key to cache only if the key does not @@ -546,7 +526,7 @@ def get_and_put_if_absent( key, value, key_hint, value_hint ) - result.value = self._process_binary(result.value) + result.value = self.client.unwrap_binary(result.value) return result @status_to_exception(CacheError) @@ -591,12 +571,12 @@ def get_and_remove(self, key, key_hint=None) -> Any: key, key_hint ) - result.value = self._process_binary(result.value) + result.value = self.client.unwrap_binary(result.value) return result @status_to_exception(CacheError) def get_and_replace( - self, key, value, key_hint=None, value_hint=None + self, key, value, key_hint=None, value_hint=None ) -> Any: """ Puts a value with a given key to cache, returning previous value @@ -620,7 +600,7 @@ def get_and_replace( key, value, key_hint, value_hint ) - result.value = self._process_binary(result.value) + result.value = self.client.unwrap_binary(result.value) return result @status_to_exception(CacheError) @@ -683,8 +663,8 @@ def remove_if_equals(self, key, sample, key_hint=None, sample_hint=None): @status_to_exception(CacheError) def replace_if_equals( - self, key, sample, value, - key_hint=None, sample_hint=None, value_hint=None + self, key, sample, value, + key_hint=None, sample_hint=None, value_hint=None ) -> Any: """ Puts a value with a given key to cache only if the key already exists @@ -710,7 +690,7 @@ def replace_if_equals( key, sample, value, key_hint, sample_hint, value_hint ) - result.value = self._process_binary(result.value) + result.value = self.client.unwrap_binary(result.value) return result @status_to_exception(CacheError) @@ -727,9 +707,7 @@ def get_size(self, peek_modes=0): self.get_best_node(), self._cache_id, peek_modes ) - def scan( - self, page_size: int = 1, partitions: int = -1, local: bool = False - ): + def scan(self, page_size: int = 1, partitions: int = -1, local: bool = False): """ Returns all key-value pairs from the cache, similar to `get_all`, but with internal pagination, which is slower, but safer. @@ -740,40 +718,14 @@ def scan( (negative to query entire cache), :param local: (optional) pass True if this query should be executed on local node only. Defaults to False, - :return: generator with key-value pairs. + :return: Scan query cursor. """ - node = self.get_best_node() - - result = scan( - node, - self._cache_id, - page_size, - partitions, - local - ) - if result.status != 0: - raise CacheError(result.message) - - cursor = result.value['cursor'] - for k, v in result.value['data'].items(): - k = self._process_binary(k) - v = self._process_binary(v) - yield k, v - - while result.value['more']: - result = scan_cursor_get_page(node, cursor) - if result.status != 0: - raise CacheError(result.message) - - for k, v in result.value['data'].items(): - k = self._process_binary(k) - v = self._process_binary(v) - yield k, v + return ScanCursor(self.client, self._cache_id, page_size, partitions, local) def select_row( - self, query_str: str, page_size: int = 1, - query_args: Optional[list] = None, distributed_joins: bool = False, - replicated_only: bool = False, local: bool = False, timeout: int = 0 + self, query_str: str, page_size: int = 1, + query_args: Optional[list] = None, distributed_joins: bool = False, + replicated_only: bool = False, local: bool = False, timeout: int = 0 ): """ Executes a simplified SQL SELECT query over data stored in the cache. @@ -791,46 +743,13 @@ def select_row( on local node only. Defaults to False, :param timeout: (optional) non-negative timeout value in ms. Zero disables timeout (default), - :return: generator with key-value pairs. - """ - node = self.get_best_node() - - def generate_result(value): - cursor = value['cursor'] - more = value['more'] - for k, v in value['data'].items(): - k = self._process_binary(k) - v = self._process_binary(v) - yield k, v - - while more: - inner_result = sql_cursor_get_page(node, cursor) - if result.status != 0: - raise SQLError(result.message) - more = inner_result.value['more'] - for k, v in inner_result.value['data'].items(): - k = self._process_binary(k) - v = self._process_binary(v) - yield k, v - + :return: Sql cursor. + """ type_name = self.settings[ prop_codes.PROP_QUERY_ENTITIES ][0]['value_type_name'] if not type_name: raise SQLError('Value type is unknown') - result = sql( - node, - self._cache_id, - type_name, - query_str, - page_size, - query_args, - distributed_joins, - replicated_only, - local, - timeout - ) - if result.status != 0: - raise SQLError(result.message) - return generate_result(result.value) + return SqlCursor(self.client, self._cache_id, type_name, query_str, page_size, query_args, + distributed_joins, replicated_only, local, timeout) diff --git a/pyignite/client.py b/pyignite/client.py index 9416474..e4eef6a 100644 --- a/pyignite/client.py +++ b/pyignite/client.py @@ -44,22 +44,20 @@ import random import re from itertools import chain -from typing import Dict, Iterable, List, Optional, Tuple, Type, Union +from typing import Iterable, Type, Union, Any from .api.binary import get_binary_type, put_binary_type from .api.cache_config import cache_get_names -from .api.sql import sql_fields, sql_fields_cursor_get_page -from .cache import Cache +from .cursors import SqlFieldsCursor +from .cache import Cache, create_cache, get_cache, get_or_create_cache from .connection import Connection -from .constants import * +from .constants import IGNITE_DEFAULT_HOST, IGNITE_DEFAULT_PORT, PROTOCOL_BYTE_ORDER from .datatypes import BinaryObject from .datatypes.internal import tc_map -from .exceptions import ( - BinaryTypeError, CacheError, ReconnectError, SQLError, connection_errors, -) +from .exceptions import BinaryTypeError, CacheError, ReconnectError, connection_errors +from .stream import BinaryStream, READ_BACKWARD from .utils import ( - cache_id, capitalize, entity_id, schema_id, process_delimiter, - status_to_exception, is_iterable, + cache_id, capitalize, entity_id, schema_id, process_delimiter, status_to_exception, is_iterable, is_wrapped ) from .binary import GenericObjectMeta @@ -67,58 +65,24 @@ __all__ = ['Client'] -class Client: - """ - This is a main `pyignite` class, that is build upon the - :class:`~pyignite.connection.Connection`. In addition to the attributes, - properties and methods of its parent class, `Client` implements - the following features: - - * cache factory. Cache objects are used for key-value operations, - * Ignite SQL endpoint, - * binary types registration endpoint. - """ - - _registry = defaultdict(dict) - _compact_footer: bool = None - _connection_args: Dict = None - _current_node: int = None - _nodes: List[Connection] = None - +class BaseClient: # used for Complex object data class names sanitizing _identifier = re.compile(r'[^0-9a-zA-Z_.+$]', re.UNICODE) _ident_start = re.compile(r'^[^a-zA-Z_]+', re.UNICODE) - affinity_version: Optional[Tuple] = None - protocol_version: Optional[Tuple] = None - - def __init__( - self, compact_footer: bool = None, partition_aware: bool = False, - **kwargs - ): - """ - Initialize client. - - :param compact_footer: (optional) use compact (True, recommended) or - full (False) schema approach when serializing Complex objects. - Default is to use the same approach the server is using (None). - Apache Ignite binary protocol documentation on this topic: - https://apacheignite.readme.io/docs/binary-client-protocol-data-format#section-schema - :param partition_aware: (optional) try to calculate the exact data - placement from the key before to issue the key operation to the - server node: - https://cwiki.apache.org/confluence/display/IGNITE/IEP-23%3A+Best+Effort+Affinity+for+thin+clients - The feature is in experimental status, so the parameter is `False` - by default. This will be changed later. - """ + def __init__(self, compact_footer: bool = None, partition_aware: bool = False, **kwargs): self._compact_footer = compact_footer + self._partition_aware = partition_aware self._connection_args = kwargs + self._registry = defaultdict(dict) self._nodes = [] self._current_node = 0 self._partition_aware = partition_aware self.affinity_version = (0, 0) + self._protocol_version = None - def get_protocol_version(self) -> Optional[Tuple]: + @property + def protocol_version(self): """ Returns the tuple of major, minor, and revision numbers of the used thin protocol version, or None, if no connection to the Ignite cluster @@ -128,7 +92,11 @@ def get_protocol_version(self) -> Optional[Tuple]: extend the `pyignite` capabilities (with additional testing, logging, examining connections, et c.) you probably should not use it. """ - return self.protocol_version + return self._protocol_version + + @protocol_version.setter + def protocol_version(self, value): + self._protocol_version = value @property def partition_aware(self): @@ -136,32 +104,182 @@ def partition_aware(self): @property def partition_awareness_supported_by_protocol(self): - # TODO: Need to re-factor this. I believe, we need separate class or - # set of functions to work with protocol versions without manually - # comparing versions with just some random tuples return self.protocol_version is not None and self.protocol_version >= (1, 4, 0) - def connect(self, *args): + @property + def compact_footer(self) -> bool: """ - Connect to Ignite cluster node(s). + This property remembers Complex object schema encoding approach when + decoding any Complex object, to use the same approach on Complex + object encoding. - :param args: (optional) host(s) and port(s) to connect to. + :return: True if compact schema was used by server or no Complex + object decoding has yet taken place, False if full schema was used. """ + # this is an ordinary object property, but its backing storage + # is a class attribute + + # use compact schema by default, but leave initial (falsy) backing + # value unchanged + return self._compact_footer or self._compact_footer is None + + @compact_footer.setter + def compact_footer(self, value: bool): + # normally schema approach should not change + if self._compact_footer not in (value, None): + raise Warning('Can not change client schema approach.') + else: + self._compact_footer = value + + @staticmethod + def _process_connect_args(*args): if len(args) == 0: # no parameters − use default Ignite host and port - nodes = [(IGNITE_DEFAULT_HOST, IGNITE_DEFAULT_PORT)] - elif len(args) == 1 and is_iterable(args[0]): + return [(IGNITE_DEFAULT_HOST, IGNITE_DEFAULT_PORT)] + if len(args) == 1 and is_iterable(args[0]): # iterable of host-port pairs is given - nodes = args[0] - elif ( - len(args) == 2 - and isinstance(args[0], str) - and isinstance(args[1], int) - ): + return args[0] + if len(args) == 2 and isinstance(args[0], str) and isinstance(args[1], int): # host and port are given - nodes = [args] - else: - raise ConnectionError('Connection parameters are not valid.') + return [args] + + raise ConnectionError('Connection parameters are not valid.') + + def _process_get_binary_type_result(self, result): + if result.status != 0 or not result.value['type_exists']: + return result + + binary_fields = result.value.pop('binary_fields') + old_format_schemas = result.value.pop('schema') + result.value['schemas'] = [] + for s_id, field_ids in old_format_schemas.items(): + result.value['schemas'].append(self._convert_schema(field_ids, binary_fields)) + return result + + @staticmethod + def _convert_type(tc_type: int): + try: + return tc_map(tc_type.to_bytes(1, PROTOCOL_BYTE_ORDER)) + except (KeyError, OverflowError): + # if conversion to char or type lookup failed, + # we probably have a binary object type ID + return BinaryObject + + def _convert_schema(self, field_ids: list, binary_fields: list) -> OrderedDict: + converted_schema = OrderedDict() + for field_id in field_ids: + binary_field = next(x for x in binary_fields if x['field_id'] == field_id) + converted_schema[binary_field['field_name']] = self._convert_type(binary_field['type_id']) + return converted_schema + + @staticmethod + def _create_dataclass(type_name: str, schema: OrderedDict = None) -> Type: + """ + Creates default (generic) class for Ignite Complex object. + + :param type_name: Complex object type name, + :param schema: Complex object schema, + :return: the resulting class. + """ + schema = schema or {} + return GenericObjectMeta(type_name, (), {}, schema=schema) + + @classmethod + def _create_type_name(cls, type_name: str) -> str: + """ + Creates Python data class name from Ignite binary type name. + + Handles all the special cases found in + `java.org.apache.ignite.binary.BinaryBasicNameMapper.simpleName()`. + Tries to adhere to PEP8 along the way. + """ + + # general sanitizing + type_name = cls._identifier.sub('', type_name) + + # - name ending with '$' (Scala) + # - name + '$' + some digits (anonymous class) + # - '$$Lambda$' in the middle + type_name = process_delimiter(type_name, '$') + + # .NET outer/inner class delimiter + type_name = process_delimiter(type_name, '+') + + # Java fully qualified class name + type_name = process_delimiter(type_name, '.') + + # start chars sanitizing + type_name = capitalize(cls._ident_start.sub('', type_name)) + + return type_name + + def _sync_binary_registry(self, type_id: int, type_info: dict): + """ + Sync binary registry + :param type_id: Complex object type ID. + :param type_info: Complex object type info. + """ + if type_info['type_exists']: + for schema in type_info['schemas']: + if not self._registry[type_id].get(schema_id(schema), None): + data_class = self._create_dataclass( + self._create_type_name(type_info['type_name']), + schema, + ) + self._registry[type_id][schema_id(schema)] = data_class + + def _get_from_registry(self, type_id, schema): + """ + Get binary type info from registry. + + :param type_id: Complex object type ID. + :param schema: Complex object schema. + """ + if schema: + try: + return self._registry[type_id][schema_id(schema)] + except KeyError: + return None + return self._registry[type_id] + + +class Client(BaseClient): + """ + This is a main `pyignite` class, that is build upon the + :class:`~pyignite.connection.Connection`. In addition to the attributes, + properties and methods of its parent class, `Client` implements + the following features: + + * cache factory. Cache objects are used for key-value operations, + * Ignite SQL endpoint, + * binary types registration endpoint. + """ + + def __init__(self, compact_footer: bool = None, partition_aware: bool = False, **kwargs): + """ + Initialize client. + + :param compact_footer: (optional) use compact (True, recommended) or + full (False) schema approach when serializing Complex objects. + Default is to use the same approach the server is using (None). + Apache Ignite binary protocol documentation on this topic: + https://apacheignite.readme.io/docs/binary-client-protocol-data-format#section-schema + :param partition_aware: (optional) try to calculate the exact data + placement from the key before to issue the key operation to the + server node: + https://cwiki.apache.org/confluence/display/IGNITE/IEP-23%3A+Best+Effort+Affinity+for+thin+clients + The feature is in experimental status, so the parameter is `False` + by default. This will be changed later. + """ + super().__init__(compact_footer, partition_aware, **kwargs) + + def connect(self, *args): + """ + Connect to Ignite cluster node(s). + + :param args: (optional) host(s) and port(s) to connect to. + """ + nodes = self._process_connect_args(*args) # the following code is quite twisted, because the protocol version # is initially unknown @@ -169,14 +287,12 @@ def connect(self, *args): # TODO: open first node in foreground, others − in background for i, node in enumerate(nodes): host, port = node - conn = Connection(self, **self._connection_args) - conn.host = host - conn.port = port + conn = Connection(self, host, port, **self._connection_args) try: if self.protocol_version is None or self.partition_aware: # open connection before adding to the pool - conn.connect(host, port) + conn.connect() # now we have the protocol version if not self.partition_aware: @@ -210,13 +326,7 @@ def random_node(self) -> Connection: """ if self.partition_aware: # if partition awareness is used just pick a random connected node - try: - return random.choice( - list(n for n in self._nodes if n.alive) - ) - except IndexError: - # cannot choose from an empty sequence - raise ReconnectError('Can not reconnect: out of nodes.') from None + return self._get_random_node() else: # if partition awareness is not used then just return the current # node if it's alive or the next usable node if connection with the @@ -238,7 +348,7 @@ def random_node(self) -> Connection: for i in chain(range(self._current_node, num_nodes), range(self._current_node)): node = self._nodes[i] try: - node.connect(node.host, node.port) + node.connect() except connection_errors: pass else: @@ -247,6 +357,19 @@ def random_node(self) -> Connection: # no nodes left raise ReconnectError('Can not reconnect: out of nodes.') + def _get_random_node(self, reconnect=True): + alive_nodes = [n for n in self._nodes if n.alive] + if alive_nodes: + return random.choice(alive_nodes) + elif reconnect: + for n in self._nodes: + n.reconnect() + + return self._get_random_node(reconnect=False) + else: + # cannot choose from an empty sequence + raise ReconnectError('Can not reconnect: out of nodes.') from None + @status_to_exception(BinaryTypeError) def get_binary_type(self, binary_type: Union[str, int]) -> dict: """ @@ -267,71 +390,8 @@ def get_binary_type(self, binary_type: Union[str, int]) -> dict: - `schemas`: a list, containing the Complex object schemas in format: OrderedDict[field name: field type hint]. A schema can be empty. """ - def convert_type(tc_type: int): - try: - return tc_map(tc_type.to_bytes(1, PROTOCOL_BYTE_ORDER)) - except (KeyError, OverflowError): - # if conversion to char or type lookup failed, - # we probably have a binary object type ID - return BinaryObject - - def convert_schema( - field_ids: list, binary_fields: list - ) -> OrderedDict: - converted_schema = OrderedDict() - for field_id in field_ids: - binary_field = [ - x - for x in binary_fields - if x['field_id'] == field_id - ][0] - converted_schema[binary_field['field_name']] = convert_type( - binary_field['type_id'] - ) - return converted_schema - - conn = self.random_node - - result = get_binary_type(conn, binary_type) - if result.status != 0 or not result.value['type_exists']: - return result - - binary_fields = result.value.pop('binary_fields') - old_format_schemas = result.value.pop('schema') - result.value['schemas'] = [] - for s_id, field_ids in old_format_schemas.items(): - result.value['schemas'].append( - convert_schema(field_ids, binary_fields) - ) - return result - - @property - def compact_footer(self) -> bool: - """ - This property remembers Complex object schema encoding approach when - decoding any Complex object, to use the same approach on Complex - object encoding. - - :return: True if compact schema was used by server or no Complex - object decoding has yet taken place, False if full schema was used. - """ - # this is an ordinary object property, but its backing storage - # is a class attribute - - # use compact schema by default, but leave initial (falsy) backing - # value unchanged - return ( - self.__class__._compact_footer - or self.__class__._compact_footer is None - ) - - @compact_footer.setter - def compact_footer(self, value: bool): - # normally schema approach should not change - if self.__class__._compact_footer not in (value, None): - raise Warning('Can not change client schema approach.') - else: - self.__class__._compact_footer = value + result = get_binary_type(self.random_node, binary_type) + return self._process_get_binary_type_result(result) @status_to_exception(BinaryTypeError) def put_binary_type( @@ -353,71 +413,9 @@ def put_binary_type( When register binary type, pass a dict of field names: field types. Binary type with no fields is OK. """ - return put_binary_type( - self.random_node, type_name, affinity_key_field, is_enum, schema - ) + return put_binary_type(self.random_node, type_name, affinity_key_field, is_enum, schema) - @staticmethod - def _create_dataclass(type_name: str, schema: OrderedDict = None) -> Type: - """ - Creates default (generic) class for Ignite Complex object. - - :param type_name: Complex object type name, - :param schema: Complex object schema, - :return: the resulting class. - """ - schema = schema or {} - return GenericObjectMeta(type_name, (), {}, schema=schema) - - def _sync_binary_registry(self, type_id: int): - """ - Reads Complex object description from Ignite server. Creates default - Complex object classes and puts in registry, if not already there. - - :param type_id: Complex object type ID. - """ - type_info = self.get_binary_type(type_id) - if type_info['type_exists']: - for schema in type_info['schemas']: - if not self._registry[type_id].get(schema_id(schema), None): - data_class = self._create_dataclass( - self._create_type_name(type_info['type_name']), - schema, - ) - self._registry[type_id][schema_id(schema)] = data_class - - @classmethod - def _create_type_name(cls, type_name: str) -> str: - """ - Creates Python data class name from Ignite binary type name. - - Handles all the special cases found in - `java.org.apache.ignite.binary.BinaryBasicNameMapper.simpleName()`. - Tries to adhere to PEP8 along the way. - """ - - # general sanitizing - type_name = cls._identifier.sub('', type_name) - - # - name ending with '$' (Scala) - # - name + '$' + some digits (anonymous class) - # - '$$Lambda$' in the middle - type_name = process_delimiter(type_name, '$') - - # .NET outer/inner class delimiter - type_name = process_delimiter(type_name, '+') - - # Java fully qualified class name - type_name = process_delimiter(type_name, '.') - - # start chars sanitizing - type_name = capitalize(cls._ident_start.sub('', type_name)) - - return type_name - - def register_binary_type( - self, data_class: Type, affinity_key_field: str = None, - ): + def register_binary_type(self, data_class: Type, affinity_key_field: str = None): """ Register the given class as a representation of a certain Complex object type. Discards autogenerated or previously registered class. @@ -425,47 +423,44 @@ def register_binary_type( :param data_class: Complex object class, :param affinity_key_field: (optional) affinity parameter. """ - if not self.query_binary_type( - data_class.type_id, data_class.schema_id - ): - self.put_binary_type( - data_class.type_name, - affinity_key_field, - schema=data_class.schema, - ) + if not self.query_binary_type(data_class.type_id, data_class.schema_id): + self.put_binary_type(data_class.type_name, affinity_key_field, schema=data_class.schema) self._registry[data_class.type_id][data_class.schema_id] = data_class - def query_binary_type( - self, binary_type: Union[int, str], schema: Union[int, dict] = None, - sync: bool = True - ): + def query_binary_type(self, binary_type: Union[int, str], schema: Union[int, dict] = None): """ Queries the registry of Complex object classes. :param binary_type: Complex object type name or ID, - :param schema: (optional) Complex object schema or schema ID, - :param sync: (optional) look up the Ignite server for registered - Complex objects and create data classes for them if needed, + :param schema: (optional) Complex object schema or schema ID :return: found dataclass or None, if `schema` parameter is provided, a dict of {schema ID: dataclass} format otherwise. """ type_id = entity_id(binary_type) - s_id = schema_id(schema) - - if schema: - try: - result = self._registry[type_id][s_id] - except KeyError: - result = None - else: - result = self._registry[type_id] - if sync and not result: - self._sync_binary_registry(type_id) - return self.query_binary_type(type_id, s_id, sync=False) + result = self._get_from_registry(type_id, schema) + if not result: + type_info = self.get_binary_type(type_id) + self._sync_binary_registry(type_id, type_info) + return self._get_from_registry(type_id, schema) return result + def unwrap_binary(self, value: Any) -> Any: + """ + Detects and recursively unwraps Binary Object. + + :param value: anything that could be a Binary Object, + :return: the result of the Binary Object unwrapping with all other data + left intact. + """ + if is_wrapped(value): + blob, offset = value + with BinaryStream(self, blob) as stream: + data_class = BinaryObject.parse(stream) + return BinaryObject.to_python(stream.read_ctype(data_class, direction=READ_BACKWARD), self) + return value + def create_cache(self, settings: Union[str, dict]) -> 'Cache': """ Creates Ignite cache by name. Raises `CacheError` if such a cache is @@ -477,7 +472,7 @@ def create_cache(self, settings: Union[str, dict]) -> 'Cache': :ref:`cache creation example `, :return: :class:`~pyignite.cache.Cache` object. """ - return Cache(self, settings) + return create_cache(self, settings) def get_or_create_cache(self, settings: Union[str, dict]) -> 'Cache': """ @@ -489,7 +484,7 @@ def get_or_create_cache(self, settings: Union[str, dict]) -> 'Cache': :ref:`cache creation example `, :return: :class:`~pyignite.cache.Cache` object. """ - return Cache(self, settings, with_get=True) + return get_or_create_cache(self, settings) def get_cache(self, settings: Union[str, dict]) -> 'Cache': """ @@ -501,7 +496,7 @@ def get_cache(self, settings: Union[str, dict]) -> 'Cache': property is allowed), :return: :class:`~pyignite.cache.Cache` object. """ - return Cache(self, settings, get_only=True) + return get_cache(self, settings) @status_to_exception(CacheError) def get_cache_names(self) -> list: @@ -559,42 +554,12 @@ def sql( :return: generator with result rows as a lists. If `include_field_names` was set, the first row will hold field names. """ - def generate_result(value): - cursor = value['cursor'] - more = value['more'] - - if include_field_names: - yield value['fields'] - field_count = len(value['fields']) - else: - field_count = value['field_count'] - for line in value['data']: - yield line - - while more: - inner_result = sql_fields_cursor_get_page( - conn, cursor, field_count - ) - if inner_result.status != 0: - raise SQLError(result.message) - more = inner_result.value['more'] - for line in inner_result.value['data']: - yield line - - conn = self.random_node c_id = cache.cache_id if isinstance(cache, Cache) else cache_id(cache) if c_id != 0: schema = None - result = sql_fields( - conn, c_id, query_str, page_size, query_args, schema, - statement_type, distributed_joins, local, replicated_only, - enforce_join_order, collocated, lazy, include_field_names, - max_rows, timeout, - ) - if result.status != 0: - raise SQLError(result.message) - - return generate_result(result.value) + return SqlFieldsCursor(self, c_id, query_str, page_size, query_args, schema, statement_type, distributed_joins, + local, replicated_only, enforce_join_order, collocated, lazy, include_field_names, + max_rows, timeout) diff --git a/pyignite/connection/__init__.py b/pyignite/connection/__init__.py index 1114594..14e820a 100644 --- a/pyignite/connection/__init__.py +++ b/pyignite/connection/__init__.py @@ -34,5 +34,6 @@ """ from .connection import Connection +from .aio_connection import AioConnection -__all__ = ['Connection'] +__all__ = ['Connection', 'AioConnection'] diff --git a/pyignite/connection/aio_connection.py b/pyignite/connection/aio_connection.py new file mode 100644 index 0000000..e5c11da --- /dev/null +++ b/pyignite/connection/aio_connection.py @@ -0,0 +1,242 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +from asyncio import Lock +from collections import OrderedDict +from io import BytesIO +from typing import Union + +from pyignite.constants import PROTOCOLS, PROTOCOL_BYTE_ORDER +from pyignite.exceptions import HandshakeError, SocketError, connection_errors +from .connection import BaseConnection + +from .handshake import HandshakeRequest, HandshakeResponse +from .ssl import create_ssl_context +from ..stream import AioBinaryStream + + +class AioConnection(BaseConnection): + """ + Asyncio connection to Ignite node. It serves multiple purposes: + + * wrapper of asyncio streams. See also https://docs.python.org/3/library/asyncio-stream.html + * encapsulates handshake and reconnection. + """ + + def __init__(self, client: 'AioClient', host: str, port: int, username: str = None, password: str = None, + **ssl_params): + """ + Initialize connection. + + For the use of the SSL-related parameters see + https://docs.python.org/3/library/ssl.html#ssl-certificates. + + :param client: Ignite client object, + :param host: Ignite server node's host name or IP, + :param port: Ignite server node's port number, + :param use_ssl: (optional) set to True if Ignite server uses SSL + on its binary connector. Defaults to use SSL when username + and password has been supplied, not to use SSL otherwise, + :param ssl_version: (optional) SSL version constant from standard + `ssl` module. Defaults to TLS v1.1, as in Ignite 2.5, + :param ssl_ciphers: (optional) ciphers to use. If not provided, + `ssl` default ciphers are used, + :param ssl_cert_reqs: (optional) determines how the remote side + certificate is treated: + + * `ssl.CERT_NONE` − remote certificate is ignored (default), + * `ssl.CERT_OPTIONAL` − remote certificate will be validated, + if provided, + * `ssl.CERT_REQUIRED` − valid remote certificate is required, + + :param ssl_keyfile: (optional) a path to SSL key file to identify + local (client) party, + :param ssl_keyfile_password: (optional) password for SSL key file, + can be provided when key file is encrypted to prevent OpenSSL + password prompt, + :param ssl_certfile: (optional) a path to ssl certificate file + to identify local (client) party, + :param ssl_ca_certfile: (optional) a path to a trusted certificate + or a certificate chain. Required to check the validity of the remote + (server-side) certificate, + :param username: (optional) user name to authenticate to Ignite + cluster, + :param password: (optional) password to authenticate to Ignite cluster. + """ + super().__init__(client, host, port, username, password, **ssl_params) + self._mux = Lock() + self._reader = None + self._writer = None + + @property + def closed(self) -> bool: + """ Tells if socket is closed. """ + return self._writer is None + + async def connect(self) -> Union[dict, OrderedDict]: + """ + Connect to the given server node with protocol version fallback. + """ + async with self._mux: + return await self._connect() + + async def _connect(self) -> Union[dict, OrderedDict]: + detecting_protocol = False + + # choose highest version first + if self.client.protocol_version is None: + detecting_protocol = True + self.client.protocol_version = max(PROTOCOLS) + + try: + result = await self._connect_version() + except HandshakeError as e: + if e.expected_version in PROTOCOLS: + self.client.protocol_version = e.expected_version + result = await self._connect_version() + else: + raise e + except connection_errors: + # restore undefined protocol version + if detecting_protocol: + self.client.protocol_version = None + raise + + # connection is ready for end user + self.uuid = result.get('node_uuid', None) # version-specific (1.4+) + + self.failed = False + return result + + async def _connect_version(self) -> Union[dict, OrderedDict]: + """ + Connect to the given server node using protocol version + defined on client. + """ + + ssl_context = create_ssl_context(self.ssl_params) + self._reader, self._writer = await asyncio.open_connection(self.host, self.port, ssl=ssl_context) + + protocol_version = self.client.protocol_version + + hs_request = HandshakeRequest( + protocol_version, + self.username, + self.password + ) + + with AioBinaryStream(self.client) as stream: + await hs_request.from_python_async(stream) + await self._send(stream.getbuffer(), reconnect=False) + + with AioBinaryStream(self.client, await self._recv(reconnect=False)) as stream: + hs_response = await HandshakeResponse.parse_async(stream, self.protocol_version) + + if hs_response.op_code == 0: + self._close() + self._process_handshake_error(hs_response) + + return hs_response + + async def reconnect(self): + async with self._mux: + await self._reconnect() + + async def _reconnect(self): + if self.alive: + return + + self._close() + + # connect and silence the connection errors + try: + await self._connect() + except connection_errors: + pass + + async def request(self, data: Union[bytes, bytearray, memoryview]) -> bytearray: + """ + Perform request. + + :param data: bytes to send. + """ + async with self._mux: + await self._send(data) + return await self._recv() + + async def _send(self, data: Union[bytes, bytearray, memoryview], reconnect=True): + if self.closed: + raise SocketError('Attempt to use closed connection.') + + try: + self._writer.write(data) + await self._writer.drain() + except connection_errors: + self.failed = True + if reconnect: + await self._reconnect() + raise + + async def _recv(self, reconnect=True) -> bytearray: + if self.closed: + raise SocketError('Attempt to use closed connection.') + + with BytesIO() as stream: + try: + buf = await self._reader.readexactly(4) + response_len = int.from_bytes(buf, PROTOCOL_BYTE_ORDER) + + stream.write(buf) + + stream.write(await self._reader.readexactly(response_len)) + except connection_errors: + self.failed = True + if reconnect: + await self._reconnect() + raise + + return bytearray(stream.getbuffer()) + + async def close(self): + async with self._mux: + self._close() + + def _close(self): + """ + Close connection. + """ + if self._writer: + try: + self._writer.close() + except connection_errors: + pass + + self._writer, self._reader = None, None diff --git a/pyignite/connection/connection.py b/pyignite/connection/connection.py index 8db304e..901cb56 100644 --- a/pyignite/connection/connection.py +++ b/pyignite/connection/connection.py @@ -32,64 +32,94 @@ import socket from typing import Union -from pyignite.constants import * -from pyignite.exceptions import ( - HandshakeError, ParameterError, SocketError, connection_errors, AuthenticationError, -) -from pyignite.datatypes import Byte, Int, Short, String, UUIDObject -from pyignite.datatypes.internal import Struct +from pyignite.constants import PROTOCOLS, IGNITE_DEFAULT_HOST, IGNITE_DEFAULT_PORT, PROTOCOL_BYTE_ORDER +from pyignite.exceptions import HandshakeError, SocketError, connection_errors, AuthenticationError -from .handshake import HandshakeRequest -from .ssl import wrap -from ..stream import BinaryStream, READ_BACKWARD +from .handshake import HandshakeRequest, HandshakeResponse +from .ssl import wrap, check_ssl_params +from ..stream import BinaryStream CLIENT_STATUS_AUTH_FAILURE = 2000 -class Connection: +class BaseConnection: + def __init__(self, client, host: str = None, port: int = None, username: str = None, password: str = None, + **ssl_params): + self.client = client + self.host = host if host else IGNITE_DEFAULT_HOST + self.port = port if port else IGNITE_DEFAULT_PORT + self.username = username + self.password = password + self.uuid = None + + check_ssl_params(ssl_params) + + if self.username and self.password and 'use_ssl' not in ssl_params: + ssl_params['use_ssl'] = True + + self.ssl_params = ssl_params + self._failed = False + + @property + def closed(self) -> bool: + """ Tells if socket is closed. """ + raise NotImplementedError + + @property + def failed(self) -> bool: + """ Tells if connection is failed. """ + return self._failed + + @failed.setter + def failed(self, value): + self._failed = value + + @property + def alive(self) -> bool: + """ Tells if connection is up and no failure detected. """ + return not self.failed and not self.closed + + def __repr__(self) -> str: + return '{}:{}'.format(self.host or '?', self.port or '?') + + @property + def protocol_version(self): + """ + Returns the tuple of major, minor, and revision numbers of the used + thin protocol version, or None, if no connection to the Ignite cluster + was yet established. + """ + return self.client.protocol_version + + def _process_handshake_error(self, response): + error_text = f'Handshake error: {response.message}' + # if handshake fails for any reason other than protocol mismatch + # (i.e. authentication error), server version is 0.0.0 + protocol_version = self.client.protocol_version + server_version = (response.version_major, response.version_minor, response.version_patch) + + if any(server_version): + error_text += f' Server expects binary protocol version ' \ + f'{server_version[0]}.{server_version[1]}.{server_version[2]}. ' \ + f'Client provides ' \ + f'{protocol_version[0]}.{protocol_version[1]}.{protocol_version[2]}.' + elif response.client_status == CLIENT_STATUS_AUTH_FAILURE: + raise AuthenticationError(error_text) + raise HandshakeError(server_version, error_text) + + +class Connection(BaseConnection): """ This is a `pyignite` class, that represents a connection to Ignite node. It serves multiple purposes: * socket wrapper. Detects fragmentation and network errors. See also https://docs.python.org/3/howto/sockets.html, - * binary protocol connector. Incapsulates handshake and failover reconnection. + * binary protocol connector. Encapsulates handshake and failover reconnection. """ - _socket = None - _failed = None - - client = None - host = None - port = None - timeout = None - username = None - password = None - ssl_params = {} - uuid = None - - @staticmethod - def _check_ssl_params(params): - expected_args = [ - 'use_ssl', - 'ssl_version', - 'ssl_ciphers', - 'ssl_cert_reqs', - 'ssl_keyfile', - 'ssl_keyfile_password', - 'ssl_certfile', - 'ssl_ca_certfile', - ] - for param in params: - if param not in expected_args: - raise ParameterError(( - 'Unexpected parameter for connection initialization: `{}`' - ).format(param)) - - def __init__( - self, client: 'Client', timeout: float = 2.0, - username: str = None, password: str = None, **ssl_params - ): + def __init__(self, client: 'Client', host: str, port: int, timeout: float = 2.0, + username: str = None, password: str = None, **ssl_params): """ Initialize connection. @@ -97,6 +127,8 @@ def __init__( https://docs.python.org/3/library/ssl.html#ssl-certificates. :param client: Ignite client object, + :param host: Ignite server node's host name or IP, + :param port: Ignite server node's port number, :param timeout: (optional) sets timeout (in seconds) for each socket operation including `connect`. 0 means non-blocking mode, which is virtually guaranteed to fail. Can accept integer or float value. @@ -130,84 +162,15 @@ def __init__( cluster, :param password: (optional) password to authenticate to Ignite cluster. """ - self.client = client + super().__init__(client, host, port, username, password, **ssl_params) self.timeout = timeout - self.username = username - self.password = password - self._check_ssl_params(ssl_params) - if self.username and self.password and 'use_ssl' not in ssl_params: - ssl_params['use_ssl'] = True - self.ssl_params = ssl_params - self._failed = False + self._socket = None @property def closed(self) -> bool: - """ Tells if socket is closed. """ return self._socket is None - @property - def failed(self) -> bool: - """ Tells if connection is failed. """ - return self._failed - - @failed.setter - def failed(self, value): - self._failed = value - - @property - def alive(self) -> bool: - """ Tells if connection is up and no failure detected. """ - return not self.failed and not self.closed - - def __repr__(self) -> str: - return '{}:{}'.format(self.host or '?', self.port or '?') - - _wrap = wrap - - def get_protocol_version(self): - """ - Returns the tuple of major, minor, and revision numbers of the used - thin protocol version, or None, if no connection to the Ignite cluster - was yet established. - """ - return self.client.protocol_version - - def read_response(self) -> Union[dict, OrderedDict]: - """ - Processes server's response to the handshake request. - - :return: handshake data. - """ - response_start = Struct([ - ('length', Int), - ('op_code', Byte), - ]) - with BinaryStream(self, self.recv(reconnect=False)) as stream: - start_class = response_start.parse(stream) - start = stream.read_ctype(start_class, direction=READ_BACKWARD) - data = response_start.to_python(start) - response_end = None - if data['op_code'] == 0: - response_end = Struct([ - ('version_major', Short), - ('version_minor', Short), - ('version_patch', Short), - ('message', String), - ('client_status', Int) - ]) - elif self.get_protocol_version() >= (1, 4, 0): - response_end = Struct([ - ('node_uuid', UUIDObject), - ]) - if response_end: - end_class = response_end.parse(stream) - end = stream.read_ctype(end_class, direction=READ_BACKWARD) - data.update(response_end.to_python(end)) - return data - - def connect( - self, host: str = None, port: int = None - ) -> Union[dict, OrderedDict]: + def connect(self) -> Union[dict, OrderedDict]: """ Connect to the given server node with protocol version fallback. @@ -222,11 +185,11 @@ def connect( self.client.protocol_version = max(PROTOCOLS) try: - result = self._connect_version(host, port) + result = self._connect_version() except HandshakeError as e: if e.expected_version in PROTOCOLS: self.client.protocol_version = e.expected_version - result = self._connect_version(host, port) + result = self._connect_version() else: raise e except connection_errors: @@ -237,28 +200,19 @@ def connect( # connection is ready for end user self.uuid = result.get('node_uuid', None) # version-specific (1.4+) - self.failed = False return result - def _connect_version( - self, host: str = None, port: int = None, - ) -> Union[dict, OrderedDict]: + def _connect_version(self) -> Union[dict, OrderedDict]: """ Connect to the given server node using protocol version defined on client. - - :param host: Ignite server node's host name or IP, - :param port: Ignite server node's port number. """ - host = host or IGNITE_DEFAULT_HOST - port = port or IGNITE_DEFAULT_PORT - self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self._socket.settimeout(self.timeout) - self._socket = self._wrap(self._socket) - self._socket.connect((host, port)) + self._socket = wrap(self._socket, self.ssl_params) + self._socket.connect((self.host, self.port)) protocol_version = self.client.protocol_version @@ -268,56 +222,41 @@ def _connect_version( self.password ) - with BinaryStream(self) as stream: + with BinaryStream(self.client) as stream: hs_request.from_python(stream) self.send(stream.getbuffer(), reconnect=False) - hs_response = self.read_response() - if hs_response['op_code'] == 0: - self.close() - - error_text = 'Handshake error: {}'.format(hs_response['message']) - # if handshake fails for any reason other than protocol mismatch - # (i.e. authentication error), server version is 0.0.0 - if any([ - hs_response['version_major'], - hs_response['version_minor'], - hs_response['version_patch'], - ]): - error_text += ( - ' Server expects binary protocol version ' - '{version_major}.{version_minor}.{version_patch}. Client ' - 'provides {client_major}.{client_minor}.{client_patch}.' - ).format( - client_major=protocol_version[0], - client_minor=protocol_version[1], - client_patch=protocol_version[2], - **hs_response - ) - elif hs_response['client_status'] == CLIENT_STATUS_AUTH_FAILURE: - raise AuthenticationError(error_text) - raise HandshakeError(( - hs_response['version_major'], - hs_response['version_minor'], - hs_response['version_patch'], - ), error_text) - self.host, self.port = host, port - return hs_response + with BinaryStream(self.client, self.recv(reconnect=False)) as stream: + hs_response = HandshakeResponse.parse(stream, self.protocol_version) + + if hs_response.op_code == 0: + self.close() + self._process_handshake_error(hs_response) + + return hs_response def reconnect(self): - # do not reconnect if connection is already working - # or was closed on purpose - if not self.failed: + if self.alive: return self.close() # connect and silence the connection errors try: - self.connect(self.host, self.port) + self.connect() except connection_errors: pass + def request(self, data: Union[bytes, bytearray, memoryview], flags=None) -> bytearray: + """ + Perform request. + + :param data: bytes to send, + :param flags: (optional) OS-specific flags. + """ + self.send(data, flags=flags) + return self.recv() + def send(self, data: Union[bytes, bytearray, memoryview], flags=None, reconnect=True): """ Send data down the socket. @@ -337,7 +276,8 @@ def send(self, data: Union[bytes, bytearray, memoryview], flags=None, reconnect= self._socket.sendall(data, **kwargs) except connection_errors: self.failed = True - self.reconnect() + if reconnect: + self.reconnect() raise def recv(self, flags=None, reconnect=True) -> bytearray: diff --git a/pyignite/connection/handshake.py b/pyignite/connection/handshake.py index 3315c4e..0b0fe50 100644 --- a/pyignite/connection/handshake.py +++ b/pyignite/connection/handshake.py @@ -15,8 +15,9 @@ from typing import Optional, Tuple -from pyignite.datatypes import Byte, Int, Short, String +from pyignite.datatypes import Byte, Int, Short, String, UUIDObject from pyignite.datatypes.internal import Struct +from pyignite.stream import READ_BACKWARD OP_HANDSHAKE = 1 @@ -51,6 +52,12 @@ def __init__( self.handshake_struct = Struct(fields) def from_python(self, stream): + self.handshake_struct.from_python(stream, self.__create_handshake_data()) + + async def from_python_async(self, stream): + await self.handshake_struct.from_python_async(stream, self.__create_handshake_data()) + + def __create_handshake_data(self): handshake_data = { 'length': 8, 'op_code': OP_HANDSHAKE, @@ -69,5 +76,66 @@ def from_python(self, stream): len(self.username), len(self.password), ]) + return handshake_data + + +class HandshakeResponse(dict): + """ + Handshake response. + """ + __response_start = Struct([ + ('length', Int), + ('op_code', Byte), + ]) + + def __init__(self, data): + super().__init__() + self.update(data) + + def __getattr__(self, item): + return self.get(item) + + @classmethod + def parse(cls, stream, protocol_version): + start_class = cls.__response_start.parse(stream) + start = stream.read_ctype(start_class, direction=READ_BACKWARD) + data = cls.__response_start.to_python(start) - self.handshake_struct.from_python(stream, handshake_data) + response_end = cls.__create_response_end(data, protocol_version) + if response_end: + end_class = response_end.parse(stream) + end = stream.read_ctype(end_class, direction=READ_BACKWARD) + data.update(response_end.to_python(end)) + + return cls(data) + + @classmethod + async def parse_async(cls, stream, protocol_version): + start_class = cls.__response_start.parse(stream) + start = stream.read_ctype(start_class, direction=READ_BACKWARD) + data = await cls.__response_start.to_python_async(start) + + response_end = cls.__create_response_end(data, protocol_version) + if response_end: + end_class = await response_end.parse_async(stream) + end = stream.read_ctype(end_class, direction=READ_BACKWARD) + data.update(await response_end.to_python_async(end)) + + return cls(data) + + @classmethod + def __create_response_end(cls, start_data, protocol_version): + response_end = None + if start_data['op_code'] == 0: + response_end = Struct([ + ('version_major', Short), + ('version_minor', Short), + ('version_patch', Short), + ('message', String), + ('client_status', Int) + ]) + elif protocol_version >= (1, 4, 0): + response_end = Struct([ + ('node_uuid', UUIDObject), + ]) + return response_end diff --git a/pyignite/connection/ssl.py b/pyignite/connection/ssl.py index 9773860..385b414 100644 --- a/pyignite/connection/ssl.py +++ b/pyignite/connection/ssl.py @@ -16,34 +16,62 @@ import ssl from ssl import SSLContext -from pyignite.constants import * +from pyignite.constants import SSL_DEFAULT_CIPHERS, SSL_DEFAULT_VERSION +from pyignite.exceptions import ParameterError -def wrap(conn: 'Connection', _socket): +def wrap(socket, ssl_params): """ Wrap socket in SSL wrapper. """ - if conn.ssl_params.get('use_ssl', None): - keyfile = conn.ssl_params.get('ssl_keyfile', None) - certfile = conn.ssl_params.get('ssl_certfile', None) + if not ssl_params.get('use_ssl'): + return socket - if keyfile and not certfile: - raise ValueError("certfile must be specified") + context = create_ssl_context(ssl_params) - password = conn.ssl_params.get('ssl_keyfile_password', None) - ssl_version = conn.ssl_params.get('ssl_version', SSL_DEFAULT_VERSION) - ciphers = conn.ssl_params.get('ssl_ciphers', SSL_DEFAULT_CIPHERS) - cert_reqs = conn.ssl_params.get('ssl_cert_reqs', ssl.CERT_NONE) - ca_certs = conn.ssl_params.get('ssl_ca_certfile', None) + return context.wrap_socket(sock=socket) - context = SSLContext(ssl_version) - context.verify_mode = cert_reqs - if ca_certs: - context.load_verify_locations(ca_certs) - if certfile: - context.load_cert_chain(certfile, keyfile, password) - if ciphers: - context.set_ciphers(ciphers) +def check_ssl_params(params): + expected_args = [ + 'use_ssl', + 'ssl_version', + 'ssl_ciphers', + 'ssl_cert_reqs', + 'ssl_keyfile', + 'ssl_keyfile_password', + 'ssl_certfile', + 'ssl_ca_certfile', + ] + for param in params: + if param not in expected_args: + raise ParameterError(( + 'Unexpected parameter for connection initialization: `{}`' + ).format(param)) - _socket = context.wrap_socket(sock=_socket) - return _socket +def create_ssl_context(ssl_params): + if not ssl_params.get('use_ssl'): + return None + + keyfile = ssl_params.get('ssl_keyfile', None) + certfile = ssl_params.get('ssl_certfile', None) + + if keyfile and not certfile: + raise ValueError("certfile must be specified") + + password = ssl_params.get('ssl_keyfile_password', None) + ssl_version = ssl_params.get('ssl_version', SSL_DEFAULT_VERSION) + ciphers = ssl_params.get('ssl_ciphers', SSL_DEFAULT_CIPHERS) + cert_reqs = ssl_params.get('ssl_cert_reqs', ssl.CERT_NONE) + ca_certs = ssl_params.get('ssl_ca_certfile', None) + + context = SSLContext(ssl_version) + context.verify_mode = cert_reqs + + if ca_certs: + context.load_verify_locations(ca_certs) + if certfile: + context.load_cert_chain(certfile, keyfile, password) + if ciphers: + context.set_ciphers(ciphers) + + return context diff --git a/pyignite/cursors.py b/pyignite/cursors.py new file mode 100644 index 0000000..c699556 --- /dev/null +++ b/pyignite/cursors.py @@ -0,0 +1,319 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This module contains sync and async cursors for different types of queries. +""" + +import asyncio + +from pyignite.api import ( + scan, scan_cursor_get_page, resource_close, scan_async, scan_cursor_get_page_async, resource_close_async, sql, + sql_cursor_get_page, sql_fields, sql_fields_cursor_get_page, sql_fields_cursor_get_page_async, sql_fields_async +) +from pyignite.exceptions import CacheError, SQLError + + +__all__ = ['ScanCursor', 'SqlCursor', 'SqlFieldsCursor', 'AioScanCursor', 'AioSqlFieldsCursor'] + + +class BaseCursorMixin: + @property + def connection(self): + return getattr(self, '_conn', None) + + @connection.setter + def connection(self, value): + setattr(self, '_conn', value) + + @property + def cursor_id(self): + return getattr(self, '_cursor_id', None) + + @cursor_id.setter + def cursor_id(self, value): + setattr(self, '_cursor_id', value) + + @property + def more(self): + return getattr(self, '_more', None) + + @more.setter + def more(self, value): + setattr(self, '_more', value) + + @property + def cache_id(self): + return getattr(self, '_cache_id', None) + + @cache_id.setter + def cache_id(self, value): + setattr(self, '_cache_id', value) + + @property + def client(self): + return getattr(self, '_client', None) + + @client.setter + def client(self, value): + setattr(self, '_client', value) + + @property + def data(self): + return getattr(self, '_data', None) + + @data.setter + def data(self, value): + setattr(self, '_data', value) + + +class CursorMixin(BaseCursorMixin): + def __enter__(self): + return self + + def __iter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + + def close(self): + if self.connection and self.cursor_id and self.more: + resource_close(self.connection, self.cursor_id) + + +class AioCursorMixin(BaseCursorMixin): + def __await__(self): + return (yield from self.__aenter__().__await__()) + + def __aiter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.close() + + async def close(self): + if self.connection and self.cursor_id and self.more: + await resource_close_async(self.connection, self.cursor_id) + + +class AbstractScanCursor: + def __init__(self, client, cache_id, page_size, partitions, local): + self.client = client + self.cache_id = cache_id + self._page_size = page_size + self._partitions = partitions + self._local = local + + def _finalize_init(self, result): + if result.status != 0: + raise CacheError(result.message) + + self.cursor_id, self.more = result.value['cursor'], result.value['more'] + self.data = iter(result.value['data'].items()) + + def _process_page_response(self, result): + if result.status != 0: + raise CacheError(result.message) + + self.data, self.more = iter(result.value['data'].items()), result.value['more'] + + +class ScanCursor(AbstractScanCursor, CursorMixin): + def __init__(self, client, cache_id, page_size, partitions, local): + super().__init__(client, cache_id, page_size, partitions, local) + + self.connection = self.client.random_node + result = scan(self.connection, self.cache_id, self._page_size, self._partitions, self._local) + self._finalize_init(result) + + def __next__(self): + if not self.data: + raise StopIteration + + try: + k, v = next(self.data) + except StopIteration: + if self.more: + self._process_page_response(scan_cursor_get_page(self.connection, self.cursor_id)) + k, v = next(self.data) + else: + raise StopIteration + + return self.client.unwrap_binary(k), self.client.unwrap_binary(v) + + +class AioScanCursor(AbstractScanCursor, AioCursorMixin): + def __init__(self, client, cache_id, page_size, partitions, local): + super().__init__(client, cache_id, page_size, partitions, local) + + async def __aenter__(self): + if not self.connection: + self.connection = await self.client.random_node() + result = await scan_async(self.connection, self.cache_id, self._page_size, self._partitions, self._local) + self._finalize_init(result) + return self + + async def __anext__(self): + if not self.connection: + raise CacheError("Using uninitialized cursor, initialize it using async with expression.") + + if not self.data: + raise StopAsyncIteration + + try: + k, v = next(self.data) + except StopIteration: + if self.more: + self._process_page_response(await scan_cursor_get_page_async(self.connection, self.cursor_id)) + try: + k, v = next(self.data) + except StopIteration: + raise StopAsyncIteration + else: + raise StopAsyncIteration + + return await asyncio.gather( + *[self.client.unwrap_binary(k), self.client.unwrap_binary(v)] + ) + + +class SqlCursor(CursorMixin): + def __init__(self, client, cache_id, *args, **kwargs): + self.client = client + self.cache_id = cache_id + self.connection = self.client.random_node + result = sql(self.connection, self.cache_id, *args, **kwargs) + if result.status != 0: + raise SQLError(result.message) + + self.cursor_id, self.more = result.value['cursor'], result.value['more'] + self.data = iter(result.value['data'].items()) + + def __next__(self): + if not self.data: + raise StopIteration + + try: + k, v = next(self.data) + except StopIteration: + if self.more: + result = sql_cursor_get_page(self.connection, self.cursor_id) + if result.status != 0: + raise SQLError(result.message) + self.data, self.more = iter(result.value['data'].items()), result.value['more'] + + k, v = next(self.data) + else: + raise StopIteration + + return self.client.unwrap_binary(k), self.client.unwrap_binary(v) + + +class AbstractSqlFieldsCursor: + def __init__(self, client, cache_id): + self.client = client + self.cache_id = cache_id + + def _finalize_init(self, result): + if result.status != 0: + raise SQLError(result.message) + + self.cursor_id, self.more = result.value['cursor'], result.value['more'] + self.data = iter(result.value['data']) + self._field_names = result.value.get('fields', None) + if self._field_names: + self._field_count = len(self._field_names) + else: + self._field_count = result.value['field_count'] + + +class SqlFieldsCursor(AbstractSqlFieldsCursor, CursorMixin): + def __init__(self, client, cache_id, *args, **kwargs): + super().__init__(client, cache_id) + self.connection = self.client.random_node + self._finalize_init(sql_fields(self.connection, self.cache_id, *args, **kwargs)) + + def __next__(self): + if not self.data: + raise StopIteration + + if self._field_names: + result = self._field_names + self._field_names = None + return result + + try: + row = next(self.data) + except StopIteration: + if self.more: + result = sql_fields_cursor_get_page(self.connection, self.cursor_id, self._field_count) + if result.status != 0: + raise SQLError(result.message) + + self.data, self.more = iter(result.value['data']), result.value['more'] + + row = next(self.data) + else: + raise StopIteration + + return [self.client.unwrap_binary(v) for v in row] + + +class AioSqlFieldsCursor(AbstractSqlFieldsCursor, AioCursorMixin): + def __init__(self, client, cache_id, *args, **kwargs): + super().__init__(client, cache_id) + self._params = (args, kwargs) + + async def __aenter__(self): + await self._initialize(*self._params[0], *self._params[1]) + return self + + async def __anext__(self): + if not self.connection: + raise SQLError("Attempting to use uninitialized aio cursor, please await on it or use with expression.") + + if not self.data: + raise StopAsyncIteration + + if self._field_names: + result = self._field_names + self._field_names = None + return result + + try: + row = next(self.data) + except StopIteration: + if self.more: + result = await sql_fields_cursor_get_page_async(self.connection, self.cursor_id, self._field_count) + if result.status != 0: + raise SQLError(result.message) + + self.data, self.more = iter(result.value['data']), result.value['more'] + try: + row = next(self.data) + except StopIteration: + raise StopAsyncIteration + else: + raise StopAsyncIteration + + return await asyncio.gather(*[self.client.unwrap_binary(v) for v in row]) + + async def _initialize(self, *args, **kwargs): + if self.connection and self.cursor_id: + return + + self.connection = await self.client.random_node() + self._finalize_init(await sql_fields_async(self.connection, self.cache_id, *args, **kwargs)) diff --git a/pyignite/datatypes/__init__.py b/pyignite/datatypes/__init__.py index 49860bd..5024f79 100644 --- a/pyignite/datatypes/__init__.py +++ b/pyignite/datatypes/__init__.py @@ -25,22 +25,3 @@ from .primitive_arrays import * from .primitive_objects import * from .standard import * -from ..stream import BinaryStream, READ_BACKWARD - - -def unwrap_binary(client: 'Client', wrapped: tuple) -> object: - """ - Unwrap wrapped BinaryObject and convert it to Python data. - - :param client: connection to Ignite cluster, - :param wrapped: `WrappedDataObject` value, - :return: dict representing wrapped BinaryObject. - """ - from pyignite.datatypes.complex import BinaryObject - - blob, offset = wrapped - with BinaryStream(client.random_node, blob) as stream: - data_class = BinaryObject.parse(stream) - result = BinaryObject.to_python(stream.read_ctype(data_class, direction=READ_BACKWARD), client) - - return result diff --git a/pyignite/datatypes/base.py b/pyignite/datatypes/base.py index 25b5b1e..fbd798b 100644 --- a/pyignite/datatypes/base.py +++ b/pyignite/datatypes/base.py @@ -47,4 +47,34 @@ class IgniteDataType(metaclass=IgniteDataTypeMeta): This is a base class for all Ignite data types, a.k.a. parser/constructor classes, both object and payload varieties. """ - pass + @classmethod + async def hashcode_async(cls, value, *args, **kwargs): + return cls.hashcode(value, *args, **kwargs) + + @classmethod + def hashcode(cls, value, *args, **kwargs): + return 0 + + @classmethod + def parse(cls, stream): + raise NotImplementedError + + @classmethod + async def parse_async(cls, stream): + return cls.parse(stream) + + @classmethod + def from_python(cls, stream, value, **kwargs): + raise NotImplementedError + + @classmethod + async def from_python_async(cls, stream, value, **kwargs): + cls.from_python(stream, value, **kwargs) + + @classmethod + def to_python(cls, ctype_object, *args, **kwargs): + raise NotImplementedError + + @classmethod + async def to_python_async(cls, ctype_object, *args, **kwargs): + return cls.to_python(ctype_object, *args, **kwargs) diff --git a/pyignite/datatypes/cache_properties.py b/pyignite/datatypes/cache_properties.py index eadaef9..127b6f3 100644 --- a/pyignite/datatypes/cache_properties.py +++ b/pyignite/datatypes/cache_properties.py @@ -23,7 +23,6 @@ from .primitive import * from .standard import * - __all__ = [ 'PropName', 'PropCacheMode', 'PropCacheAtomicityMode', 'PropBackupsNumber', 'PropWriteSynchronizationMode', 'PropCopyOnRead', 'PropReadFromBackup', @@ -81,7 +80,7 @@ class PropBase: @classmethod def build_header(cls): return type( - cls.__name__+'Header', + cls.__name__ + 'Header', (ctypes.LittleEndianStructure,), { '_pack_': 1, @@ -111,11 +110,17 @@ def parse(cls, stream): stream.seek(init_pos + ctypes.sizeof(prop_class)) return prop_class + @classmethod + async def parse_async(cls, stream): + return cls.parse(stream) + @classmethod def to_python(cls, ctype_object, *args, **kwargs): - return cls.prop_data_class.to_python( - ctype_object.data, *args, **kwargs - ) + return cls.prop_data_class.to_python(ctype_object.data, *args, **kwargs) + + @classmethod + async def to_python_async(cls, ctype_object, *args, **kwargs): + return cls.to_python(ctype_object, *args, **kwargs) @classmethod def from_python(cls, stream, value): @@ -125,6 +130,10 @@ def from_python(cls, stream, value): stream.write(bytes(header)) cls.prop_data_class.from_python(stream, value) + @classmethod + async def from_python_async(cls, stream, value): + return cls.from_python(stream, value) + class PropName(PropBase): prop_code = PROP_NAME diff --git a/pyignite/datatypes/complex.py b/pyignite/datatypes/complex.py index b8d9c02..5cb6160 100644 --- a/pyignite/datatypes/complex.py +++ b/pyignite/datatypes/complex.py @@ -12,30 +12,25 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import asyncio from collections import OrderedDict import ctypes from io import SEEK_CUR -from typing import Iterable, Dict +from typing import Optional from pyignite.constants import * from pyignite.exceptions import ParseError -from .base import IgniteDataType -from .internal import AnyDataObject, infer_from_python +from .internal import AnyDataObject, Struct, infer_from_python, infer_from_python_async from .type_codes import * from .type_ids import * from .type_names import * from .null_object import Null, Nullable +from ..stream import AioBinaryStream, BinaryStream -__all__ = [ - 'Map', 'ObjectArrayObject', 'CollectionObject', 'MapObject', - 'WrappedDataObject', 'BinaryObject', -] - -from ..stream import BinaryStream +__all__ = ['Map', 'ObjectArrayObject', 'CollectionObject', 'MapObject', 'WrappedDataObject', 'BinaryObject'] -class ObjectArrayObject(IgniteDataType, Nullable): +class ObjectArrayObject(Nullable): """ Array of Ignite objects of any consistent type. Its Python representation is tuple(type_id, iterable of any type). The only type ID that makes sense @@ -48,15 +43,10 @@ class ObjectArrayObject(IgniteDataType, Nullable): _type_id = TYPE_OBJ_ARR type_code = TC_OBJECT_ARRAY - @staticmethod - def hashcode(value: Iterable) -> int: - # Arrays are not supported as keys at the moment. - return 0 - @classmethod def build_header(cls): return type( - cls.__name__+'Header', + cls.__name__ + 'Header', (ctypes.LittleEndianStructure,), { '_pack_': 1, @@ -70,16 +60,36 @@ def build_header(cls): @classmethod def parse_not_null(cls, stream): - header_class = cls.build_header() - header = stream.read_ctype(header_class) - stream.seek(ctypes.sizeof(header_class), SEEK_CUR) + header, header_class = cls.__parse_header(stream) fields = [] for i in range(header.length): c_type = AnyDataObject.parse(stream) fields.append(('element_{}'.format(i), c_type)) - final_class = type( + return cls.__build_final_class(header_class, fields) + + @classmethod + async def parse_not_null_async(cls, stream): + header, header_class = cls.__parse_header(stream) + + fields = [] + for i in range(header.length): + c_type = await AnyDataObject.parse_async(stream) + fields.append(('element_{}'.format(i), c_type)) + + return cls.__build_final_class(header_class, fields) + + @classmethod + def __parse_header(cls, stream): + header_class = cls.build_header() + header = stream.read_ctype(header_class) + stream.seek(ctypes.sizeof(header_class), SEEK_CUR) + return header, header_class + + @classmethod + def __build_final_class(cls, header_class, fields): + return type( cls.__name__, (header_class,), { @@ -88,8 +98,6 @@ def parse_not_null(cls, stream): } ) - return final_class - @classmethod def to_python_not_null(cls, ctype_object, *args, **kwargs): result = [] @@ -103,28 +111,55 @@ def to_python_not_null(cls, ctype_object, *args, **kwargs): return ctype_object.type_id, result @classmethod - def from_python_not_null(cls, stream, value): + async def to_python_not_null_async(cls, ctype_object, *args, **kwargs): + result = [ + await AnyDataObject.to_python_async( + getattr(ctype_object, 'element_{}'.format(i)), *args, **kwargs + ) + for i in range(ctype_object.length)] + return ctype_object.type_id, result + + @classmethod + def from_python_not_null(cls, stream, value, *args, **kwargs): + type_or_id, value = value + try: + length = len(value) + except TypeError: + value = [value] + length = 1 + + cls.__write_header(stream, type_or_id, length) + for x in value: + infer_from_python(stream, x) + + @classmethod + async def from_python_not_null_async(cls, stream, value, *args, **kwargs): type_or_id, value = value + try: + length = len(value) + except TypeError: + value = [value] + length = 1 + + cls.__write_header(stream, type_or_id, length) + for x in value: + await infer_from_python_async(stream, x) + + @classmethod + def __write_header(cls, stream, type_or_id, length): header_class = cls.build_header() header = header_class() header.type_code = int.from_bytes( cls.type_code, byteorder=PROTOCOL_BYTE_ORDER ) - try: - length = len(value) - except TypeError: - value = [value] - length = 1 header.length = length header.type_id = type_or_id stream.write(header) - for x in value: - infer_from_python(stream, x) -class WrappedDataObject(IgniteDataType, Nullable): +class WrappedDataObject(Nullable): """ One or more binary objects can be wrapped in an array. This allows reading, storing, passing and writing objects efficiently without understanding @@ -138,7 +173,7 @@ class WrappedDataObject(IgniteDataType, Nullable): @classmethod def build_header(cls): return type( - cls.__name__+'Header', + cls.__name__ + 'Header', (ctypes.LittleEndianStructure,), { '_pack_': 1, @@ -160,7 +195,7 @@ def parse_not_null(cls, stream): { '_pack_': 1, '_fields_': [ - ('payload', ctypes.c_byte*header.length), + ('payload', ctypes.c_byte * header.length), ('offset', ctypes.c_int), ], } @@ -170,15 +205,15 @@ def parse_not_null(cls, stream): return final_class @classmethod - def to_python(cls, ctype_object, *args, **kwargs): + def to_python_not_null(cls, ctype_object, *args, **kwargs): return bytes(ctype_object.payload), ctype_object.offset @classmethod - def from_python(cls, stream, value): + def from_python(cls, stream, value, *args, **kwargs): raise ParseError('Send unwrapped data.') -class CollectionObject(IgniteDataType, Nullable): +class CollectionObject(Nullable): """ Similar to object array, but contains platform-agnostic deserialization type hint instead of type ID. @@ -220,15 +255,10 @@ class CollectionObject(IgniteDataType, Nullable): pythonic = list default = [] - @staticmethod - def hashcode(value: Iterable) -> int: - # Collections are not supported as keys at the moment. - return 0 - @classmethod def build_header(cls): return type( - cls.__name__+'Header', + cls.__name__ + 'Header', (ctypes.LittleEndianStructure,), { '_pack_': 1, @@ -242,16 +272,36 @@ def build_header(cls): @classmethod def parse_not_null(cls, stream): - header_class = cls.build_header() - header = stream.read_ctype(header_class) - stream.seek(ctypes.sizeof(header_class), SEEK_CUR) + header, header_class = cls.__parse_header(stream) fields = [] for i in range(header.length): c_type = AnyDataObject.parse(stream) fields.append(('element_{}'.format(i), c_type)) - final_class = type( + return cls.__build_final_class(header_class, fields) + + @classmethod + async def parse_not_null_async(cls, stream): + header, header_class = cls.__parse_header(stream) + + fields = [] + for i in range(header.length): + c_type = await AnyDataObject.parse_async(stream) + fields.append(('element_{}'.format(i), c_type)) + + return cls.__build_final_class(header_class, fields) + + @classmethod + def __parse_header(cls, stream): + header_class = cls.build_header() + header = stream.read_ctype(header_class) + stream.seek(ctypes.sizeof(header_class), SEEK_CUR) + return header, header_class + + @classmethod + def __build_final_class(cls, header_class, fields): + return type( cls.__name__, (header_class,), { @@ -259,46 +309,78 @@ def parse_not_null(cls, stream): '_fields_': fields, } ) - return final_class @classmethod def to_python(cls, ctype_object, *args, **kwargs): - result = [] - length = getattr(ctype_object, "length", None) + length = cls.__get_length(ctype_object) if length is None: return None - for i in range(length): - result.append( - AnyDataObject.to_python( - getattr(ctype_object, 'element_{}'.format(i)), - *args, **kwargs - ) - ) + + result = [ + AnyDataObject.to_python(getattr(ctype_object, f'element_{i}'), *args, **kwargs) + for i in range(length) + ] return ctype_object.type, result @classmethod - def from_python_not_null(cls, stream, value): + async def to_python_async(cls, ctype_object, *args, **kwargs): + length = cls.__get_length(ctype_object) + if length is None: + return None + + result_coro = [ + AnyDataObject.to_python_async(getattr(ctype_object, f'element_{i}'), *args, **kwargs) + for i in range(length) + ] + + return ctype_object.type, await asyncio.gather(*result_coro) + + @classmethod + def __get_length(cls, ctype_object): + return getattr(ctype_object, "length", None) + + @classmethod + def from_python_not_null(cls, stream, value, *args, **kwargs): type_or_id, value = value + try: + length = len(value) + except TypeError: + value = [value] + length = 1 + + cls.__write_header(stream, type_or_id, length) + for x in value: + infer_from_python(stream, x) + + @classmethod + async def from_python_not_null_async(cls, stream, value, *args, **kwargs): + type_or_id, value = value + try: + length = len(value) + except TypeError: + value = [value] + length = 1 + + cls.__write_header(stream, type_or_id, length) + for x in value: + await infer_from_python_async(stream, x) + + @classmethod + def __write_header(cls, stream, type_or_id, length): header_class = cls.build_header() header = header_class() header.type_code = int.from_bytes( cls.type_code, byteorder=PROTOCOL_BYTE_ORDER ) - try: - length = len(value) - except TypeError: - value = [value] - length = 1 + header.length = length header.type = type_or_id stream.write(header) - for x in value: - infer_from_python(stream, x) -class Map(IgniteDataType, Nullable): +class Map(Nullable): """ Dictionary type, payload-only. @@ -310,15 +392,10 @@ class Map(IgniteDataType, Nullable): HASH_MAP = 1 LINKED_HASH_MAP = 2 - @staticmethod - def hashcode(value: Dict) -> int: - # Maps are not supported as keys at the moment. - return 0 - @classmethod def build_header(cls): return type( - cls.__name__+'Header', + cls.__name__ + 'Header', (ctypes.LittleEndianStructure,), { '_pack_': 1, @@ -330,16 +407,36 @@ def build_header(cls): @classmethod def parse_not_null(cls, stream): - header_class = cls.build_header() - header = stream.read_ctype(header_class) - stream.seek(ctypes.sizeof(header_class), SEEK_CUR) + header, header_class = cls.__parse_header(stream) fields = [] for i in range(header.length << 1): c_type = AnyDataObject.parse(stream) fields.append(('element_{}'.format(i), c_type)) - final_class = type( + return cls.__build_final_class(header_class, fields) + + @classmethod + async def parse_not_null_async(cls, stream): + header, header_class = cls.__parse_header(stream) + + fields = [] + for i in range(header.length << 1): + c_type = await AnyDataObject.parse_async(stream) + fields.append(('element_{}'.format(i), c_type)) + + return cls.__build_final_class(header_class, fields) + + @classmethod + def __parse_header(cls, stream): + header_class = cls.build_header() + header = stream.read_ctype(header_class) + stream.seek(ctypes.sizeof(header_class), SEEK_CUR) + return header, header_class + + @classmethod + def __build_final_class(cls, header_class, fields): + return type( cls.__name__, (header_class,), { @@ -347,43 +444,75 @@ def parse_not_null(cls, stream): '_fields_': fields, } ) - return final_class @classmethod def to_python(cls, ctype_object, *args, **kwargs): - map_type = getattr(ctype_object, 'type', cls.HASH_MAP) - result = OrderedDict() if map_type == cls.LINKED_HASH_MAP else {} + map_cls = cls.__get_map_class(ctype_object) + result = map_cls() for i in range(0, ctype_object.length << 1, 2): k = AnyDataObject.to_python( + getattr(ctype_object, 'element_{}'.format(i)), + *args, **kwargs + ) + v = AnyDataObject.to_python( + getattr(ctype_object, 'element_{}'.format(i + 1)), + *args, **kwargs + ) + result[k] = v + return result + + @classmethod + async def to_python_async(cls, ctype_object, *args, **kwargs): + map_cls = cls.__get_map_class(ctype_object) + + kv_pairs_coro = [ + asyncio.gather( + AnyDataObject.to_python_async( getattr(ctype_object, 'element_{}'.format(i)), *args, **kwargs - ) - v = AnyDataObject.to_python( + ), + AnyDataObject.to_python_async( getattr(ctype_object, 'element_{}'.format(i + 1)), *args, **kwargs ) - result[k] = v - return result + ) for i in range(0, ctype_object.length << 1, 2) + ] + + return map_cls(await asyncio.gather(*kv_pairs_coro)) + + @classmethod + def __get_map_class(cls, ctype_object): + map_type = getattr(ctype_object, 'type', cls.HASH_MAP) + return OrderedDict if map_type == cls.LINKED_HASH_MAP else dict @classmethod def from_python(cls, stream, value, type_id=None): + cls.__write_header(stream, type_id, len(value)) + for k, v in value.items(): + infer_from_python(stream, k) + infer_from_python(stream, v) + + @classmethod + async def from_python_async(cls, stream, value, type_id=None): + cls.__write_header(stream, type_id, len(value)) + for k, v in value.items(): + await infer_from_python_async(stream, k) + await infer_from_python_async(stream, v) + + @classmethod + def __write_header(cls, stream, type_id, length): header_class = cls.build_header() header = header_class() - length = len(value) header.length = length + if hasattr(header, 'type_code'): - header.type_code = int.from_bytes( - cls.type_code, - byteorder=PROTOCOL_BYTE_ORDER - ) + header.type_code = int.from_bytes(cls.type_code, byteorder=PROTOCOL_BYTE_ORDER) + if hasattr(header, 'type'): header.type = type_id stream.write(header) - for k, v in value.items(): - infer_from_python(stream, k) - infer_from_python(stream, v) class MapObject(Map): @@ -404,7 +533,7 @@ class MapObject(Map): @classmethod def build_header(cls): return type( - cls.__name__+'Header', + cls.__name__ + 'Header', (ctypes.LittleEndianStructure,), { '_pack_': 1, @@ -419,23 +548,43 @@ def build_header(cls): @classmethod def to_python(cls, ctype_object, *args, **kwargs): obj_type = getattr(ctype_object, "type", None) - if obj_type is None: - return None - return obj_type, super().to_python( - ctype_object, *args, **kwargs - ) + if obj_type: + return obj_type, super().to_python(ctype_object, *args, **kwargs) + return None + + @classmethod + async def to_python_async(cls, ctype_object, *args, **kwargs): + obj_type = getattr(ctype_object, "type", None) + if obj_type: + return obj_type, await super().to_python_async(ctype_object, *args, **kwargs) + return None + + @classmethod + def __get_obj_type(cls, ctype_object): + return getattr(ctype_object, "type", None) + + @classmethod + def from_python(cls, stream, value, **kwargs): + type_id, value = cls.__unpack_value(stream, value) + if value: + super().from_python(stream, value, type_id) @classmethod - def from_python(cls, stream, value): + async def from_python_async(cls, stream, value, **kwargs): + type_id, value = cls.__unpack_value(stream, value) + if value: + await super().from_python_async(stream, value, type_id) + + @classmethod + def __unpack_value(cls, stream, value): if value is None: Null.from_python(stream) - return + return None, None - type_id, value = value - super().from_python(stream, value, type_id) + return value -class BinaryObject(IgniteDataType, Nullable): +class BinaryObject(Nullable): _type_id = TYPE_BINARY_OBJ type_code = TC_COMPLEX_OBJECT @@ -446,18 +595,25 @@ class BinaryObject(IgniteDataType, Nullable): OFFSET_TWO_BYTES = 0x0010 COMPACT_FOOTER = 0x0020 - @staticmethod - def hashcode(value: object, client: None) -> int: + @classmethod + def hashcode(cls, value: object, client: Optional['Client']) -> int: # binary objects's hashcode implementation is special in the sense # that you need to fully serialize the object to calculate # its hashcode - if not value._hashcode and client : - - with BinaryStream(client.random_node) as stream: + if not value._hashcode and client: + with BinaryStream(client) as stream: value._from_python(stream, save_to_buf=True) return value._hashcode + @classmethod + async def hashcode_async(cls, value: object, client: Optional['AioClient']) -> int: + if not value._hashcode and client: + with AioBinaryStream(client) as stream: + await value._from_python_async(stream, save_to_buf=True) + + return value._hashcode + @classmethod def build_header(cls): return type( @@ -504,22 +660,47 @@ def schema_type(cls, flags: int): @classmethod def parse_not_null(cls, stream): - from pyignite.datatypes import Struct + header, header_class = cls.__parse_header(stream) + + # ignore full schema, always retrieve fields' types and order + # from complex types registry + data_class = stream.get_dataclass(header) + object_fields_struct = cls.__build_object_fields_struct(data_class) + object_fields = object_fields_struct.parse(stream) + + return cls.__build_final_class(stream, header, header_class, object_fields, + len(object_fields_struct.fields)) + @classmethod + async def parse_not_null_async(cls, stream): + header, header_class = cls.__parse_header(stream) + + # ignore full schema, always retrieve fields' types and order + # from complex types registry + data_class = await stream.get_dataclass(header) + object_fields_struct = cls.__build_object_fields_struct(data_class) + object_fields = await object_fields_struct.parse_async(stream) + + return cls.__build_final_class(stream, header, header_class, object_fields, + len(object_fields_struct.fields)) + + @classmethod + def __parse_header(cls, stream): header_class = cls.build_header() header = stream.read_ctype(header_class) stream.seek(ctypes.sizeof(header_class), SEEK_CUR) + return header, header_class - # ignore full schema, always retrieve fields' types and order - # from complex types registry - data_class = stream.get_dataclass(header) + @staticmethod + def __build_object_fields_struct(data_class): fields = data_class.schema.items() - object_fields_struct = Struct(fields) - object_fields = object_fields_struct.parse(stream) - final_class_fields = [('object_fields', object_fields)] + return Struct(fields) + @classmethod + def __build_final_class(cls, stream, header, header_class, object_fields, fields_len): + final_class_fields = [('object_fields', object_fields)] if header.flags & cls.HAS_SCHEMA: - schema = cls.schema_type(header.flags) * len(fields) + schema = cls.schema_type(header.flags) * fields_len stream.seek(ctypes.sizeof(schema), SEEK_CUR) final_class_fields.append(('schema', schema)) @@ -537,35 +718,71 @@ def parse_not_null(cls, stream): @classmethod def to_python(cls, ctype_object, client: 'Client' = None, *args, **kwargs): - type_id = getattr(ctype_object, "type_id", None) - if type_id is None: - return None + type_id = cls.__get_type_id(ctype_object, client) + if type_id: + data_class = client.query_binary_type(type_id, ctype_object.schema_id) + + result = data_class() + result.version = ctype_object.version + for field_name, field_type in data_class.schema.items(): + setattr( + result, field_name, field_type.to_python( + getattr(ctype_object.object_fields, field_name), + client, *args, **kwargs + ) + ) + return result - if not client: - raise ParseError( - 'Can not query binary type {}'.format(type_id) - ) + return None - data_class = client.query_binary_type( - type_id, - ctype_object.schema_id - ) - result = data_class() - - result.version = ctype_object.version - for field_name, field_type in data_class.schema.items(): - setattr( - result, field_name, field_type.to_python( - getattr(ctype_object.object_fields, field_name), - client, *args, **kwargs - ) + @classmethod + async def to_python_async(cls, ctype_object, client: 'AioClient' = None, *args, **kwargs): + type_id = cls.__get_type_id(ctype_object, client) + if type_id: + data_class = await client.query_binary_type(type_id, ctype_object.schema_id) + + result = data_class() + result.version = ctype_object.version + + field_values = await asyncio.gather( + *[ + field_type.to_python_async( + getattr(ctype_object.object_fields, field_name), client, *args, **kwargs + ) + for field_name, field_type in data_class.schema.items() + ] ) - return result + + for i, field_name in enumerate(data_class.schema.keys()): + setattr(result, field_name, field_values[i]) + + return result + return None @classmethod - def from_python_not_null(cls, stream, value): - if getattr(value, '_buffer', None): - stream.write(value._buffer) - else: + def __get_type_id(cls, ctype_object, client): + type_id = getattr(ctype_object, "type_id", None) + if type_id: + if not client: + raise ParseError(f'Can not query binary type {type_id}') + return type_id + return None + + @classmethod + def from_python_not_null(cls, stream, value, **kwargs): + if cls.__write_fast_path(stream, value): stream.register_binary_type(value.__class__) value._from_python(stream) + + @classmethod + async def from_python_not_null_async(cls, stream, value, **kwargs): + if cls.__write_fast_path(stream, value): + await stream.register_binary_type(value.__class__) + await value._from_python_async(stream) + + @classmethod + def __write_fast_path(cls, stream, value): + if getattr(value, '_buffer', None): + stream.write(value._buffer) + return False + return True diff --git a/pyignite/datatypes/internal.py b/pyignite/datatypes/internal.py index a6da9fe..0de50e2 100644 --- a/pyignite/datatypes/internal.py +++ b/pyignite/datatypes/internal.py @@ -12,26 +12,25 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import asyncio from collections import OrderedDict import ctypes import decimal from datetime import date, datetime, timedelta from io import SEEK_CUR -from typing import Any, Tuple, Union, Callable, List +from typing import Any, Union, Callable, List import uuid import attr -from pyignite.constants import * +from pyignite.constants import PROTOCOL_BYTE_ORDER from pyignite.exceptions import ParseError from pyignite.utils import is_binary, is_hinted, is_iterable from .type_codes import * __all__ = [ - 'AnyDataArray', 'AnyDataObject', 'Struct', 'StructArray', 'tc_map', - 'infer_from_python', + 'AnyDataArray', 'AnyDataObject', 'Struct', 'StructArray', 'tc_map', 'infer_from_python', 'infer_from_python_async' ] from ..stream import READ_BACKWARD @@ -124,11 +123,25 @@ def __init__(self, fields: List, predicate1: Callable[[any], bool], self.var2 = var2 def parse(self, stream, context): - return self.var1.parse(stream) if self.predicate1(context) else self.var2.parse(stream) + if self.predicate1(context): + return self.var1.parse(stream) + return self.var2.parse(stream) + + async def parse_async(self, stream, context): + if self.predicate1(context): + return await self.var1.parse_async(stream) + return await self.var2.parse_async(stream) def to_python(self, ctype_object, context, *args, **kwargs): - return self.var1.to_python(ctype_object, *args, **kwargs) if self.predicate2(context)\ - else self.var2.to_python(ctype_object, *args, **kwargs) + if self.predicate2(context): + return self.var1.to_python(ctype_object, *args, **kwargs) + return self.var2.to_python(ctype_object, *args, **kwargs) + + async def to_python_async(self, ctype_object, context, *args, **kwargs): + if self.predicate2(context): + return await self.var1.to_python_async(ctype_object, *args, **kwargs) + return await self.var2.to_python_async(ctype_object, *args, **kwargs) + @attr.s class StructArray: @@ -139,7 +152,7 @@ class StructArray: def build_header_class(self): return type( - self.__class__.__name__+'Header', + self.__class__.__name__ + 'Header', (ctypes.LittleEndianStructure,), { '_pack_': 1, @@ -150,19 +163,34 @@ def build_header_class(self): ) def parse(self, stream): + fields, length = [], self.__parse_length(stream) + + for i in range(length): + c_type = Struct(self.following).parse(stream) + fields.append(('element_{}'.format(i), c_type)) + + return self.__build_final_class(fields) + + async def parse_async(self, stream): + fields, length = [], self.__parse_length(stream) + + for i in range(length): + c_type = await Struct(self.following).parse_async(stream) + fields.append(('element_{}'.format(i), c_type)) + + return self.__build_final_class(fields) + + def __parse_length(self, stream): counter_type_len = ctypes.sizeof(self.counter_type) length = int.from_bytes( stream.mem_view(offset=counter_type_len), byteorder=PROTOCOL_BYTE_ORDER ) stream.seek(counter_type_len, SEEK_CUR) + return length - fields = [] - for i in range(length): - c_type = Struct(self.following).parse(stream) - fields.append(('element_{}'.format(i), c_type)) - - data_class = type( + def __build_final_class(self, fields): + return type( 'StructArray', (self.build_header_class(),), { @@ -171,36 +199,47 @@ def parse(self, stream): }, ) - return data_class - def to_python(self, ctype_object, *args, **kwargs): - result = [] length = getattr(ctype_object, 'length', 0) - for i in range(length): - result.append( - Struct( - self.following, dict_type=dict - ).to_python( - getattr(ctype_object, 'element_{}'.format(i)), - *args, **kwargs - ) - ) - return result + return [ + Struct(self.following, dict_type=dict).to_python(getattr(ctype_object, 'element_{}'.format(i)), + *args, **kwargs) + for i in range(length) + ] - def from_python(self, stream, value): - length = len(value) - header_class = self.build_header_class() - header = header_class() - header.length = length + async def to_python_async(self, ctype_object, *args, **kwargs): + length = getattr(ctype_object, 'length', 0) + result_coro = [ + Struct(self.following, dict_type=dict).to_python_async(getattr(ctype_object, 'element_{}'.format(i)), + *args, **kwargs) + for i in range(length) + ] + return await asyncio.gather(*result_coro) + def from_python(self, stream, value): + self.__write_header(stream, len(value)) - stream.write(header) - for i, v in enumerate(value): + for v in value: for default_key, default_value in self.defaults.items(): v.setdefault(default_key, default_value) for name, el_class in self.following: el_class.from_python(stream, v[name]) + async def from_python_async(self, stream, value): + self.__write_header(stream, len(value)) + + for v in value: + for default_key, default_value in self.defaults.items(): + v.setdefault(default_key, default_value) + for name, el_class in self.following: + await el_class.from_python_async(stream, v[name]) + + def __write_header(self, stream, length): + header_class = self.build_header_class() + header = header_class() + header.length = length + stream.write(header) + @attr.s class Struct: @@ -210,12 +249,7 @@ class Struct: defaults = attr.ib(type=dict, default={}) def parse(self, stream): - fields, ctx = [], {} - - for _, c_type in self.fields: - if isinstance(c_type, Conditional): - for name in c_type.fields: - ctx[name] = None + fields, ctx = [], self.__prepare_conditional_ctx() for name, c_type in self.fields: is_cond = isinstance(c_type, Conditional) @@ -224,7 +258,31 @@ def parse(self, stream): if name in ctx: ctx[name] = stream.read_ctype(c_type, direction=READ_BACKWARD) - data_class = type( + return self.__build_final_class(fields) + + async def parse_async(self, stream): + fields, ctx = [], self.__prepare_conditional_ctx() + + for name, c_type in self.fields: + is_cond = isinstance(c_type, Conditional) + c_type = await c_type.parse_async(stream, ctx) if is_cond else await c_type.parse_async(stream) + fields.append((name, c_type)) + if name in ctx: + ctx[name] = stream.read_ctype(c_type, direction=READ_BACKWARD) + + return self.__build_final_class(fields) + + def __prepare_conditional_ctx(self): + ctx = {} + for _, c_type in self.fields: + if isinstance(c_type, Conditional): + for name in c_type.fields: + ctx[name] = None + return ctx + + @staticmethod + def __build_final_class(fields): + return type( 'Struct', (ctypes.LittleEndianStructure,), { @@ -233,11 +291,7 @@ def parse(self, stream): }, ) - return data_class - - def to_python( - self, ctype_object, *args, **kwargs - ) -> Union[dict, OrderedDict]: + def to_python(self, ctype_object, *args, **kwargs) -> Union[dict, OrderedDict]: result = self.dict_type() for name, c_type in self.fields: is_cond = isinstance(c_type, Conditional) @@ -251,13 +305,41 @@ def to_python( ) return result + async def to_python_async(self, ctype_object, *args, **kwargs) -> Union[dict, OrderedDict]: + result = self.dict_type() + for name, c_type in self.fields: + is_cond = isinstance(c_type, Conditional) + + if is_cond: + value = await c_type.to_python_async( + getattr(ctype_object, name), + result, + *args, **kwargs + ) + else: + value = await c_type.to_python_async( + getattr(ctype_object, name), + *args, **kwargs + ) + result[name] = value + return result + def from_python(self, stream, value): - for default_key, default_value in self.defaults.items(): - value.setdefault(default_key, default_value) + self.__set_defaults(value) for name, el_class in self.fields: el_class.from_python(stream, value[name]) + async def from_python_async(self, stream, value): + self.__set_defaults(value) + + for name, el_class in self.fields: + await el_class.from_python_async(stream, value[name]) + + def __set_defaults(self, value): + for default_key, default_value in self.defaults.items(): + value.setdefault(default_key, default_value) + class AnyDataObject: """ @@ -294,29 +376,44 @@ def get_subtype(iterable, allow_none=False): # if an iterable contains items of more than one non-nullable type, # return None - if all([ - isinstance(x, type_first) - or ((x is None) and allow_none) for x in iterator - ]): + if all(isinstance(x, type_first) or ((x is None) and allow_none) for x in iterator): return type_first @classmethod def parse(cls, stream): + data_class = cls.__data_class_parse(stream) + return data_class.parse(stream) + + @classmethod + async def parse_async(cls, stream): + data_class = cls.__data_class_parse(stream) + return await data_class.parse_async(stream) + + @classmethod + def __data_class_parse(cls, stream): type_code = bytes(stream.mem_view(offset=ctypes.sizeof(ctypes.c_byte))) try: - data_class = tc_map(type_code) + return tc_map(type_code) except KeyError: raise ParseError('Unknown type code: `{}`'.format(type_code)) - return data_class.parse(stream) @classmethod def to_python(cls, ctype_object, *args, **kwargs): + data_class = cls.__data_class_from_ctype(ctype_object) + return data_class.to_python(ctype_object) + + @classmethod + async def to_python_async(cls, ctype_object, *args, **kwargs): + data_class = cls.__data_class_from_ctype(ctype_object) + return await data_class.to_python_async(ctype_object) + + @classmethod + def __data_class_from_ctype(cls, ctype_object): type_code = ctype_object.type_code.to_bytes( ctypes.sizeof(ctypes.c_byte), byteorder=PROTOCOL_BYTE_ORDER ) - data_class = tc_map(type_code) - return data_class.to_python(ctype_object) + return tc_map(type_code) @classmethod def _init_python_map(cls): @@ -423,6 +520,11 @@ def from_python(cls, stream, value): p_type = cls.map_python_type(value) p_type.from_python(stream, value) + @classmethod + async def from_python_async(cls, stream, value): + p_type = cls.map_python_type(value) + await p_type.from_python_async(stream, value) + def infer_from_python(stream, value: Any): """ @@ -431,14 +533,26 @@ def infer_from_python(stream, value: Any): :param value: pythonic value or (value, type_hint) tuple, :return: bytes. """ - if is_hinted(value): - value, data_type = value - else: - data_type = AnyDataObject + value, data_type = __unpack_hinted(value) data_type.from_python(stream, value) +async def infer_from_python_async(stream, value: Any): + """ + Async version of infer_from_python + """ + value, data_type = __unpack_hinted(value) + + await data_type.from_python_async(stream, value) + + +def __unpack_hinted(value): + if is_hinted(value): + return value + return value, AnyDataObject + + @attr.s class AnyDataArray(AnyDataObject): """ @@ -448,7 +562,7 @@ class AnyDataArray(AnyDataObject): def build_header(self): return type( - self.__class__.__name__+'Header', + self.__class__.__name__ + 'Header', (ctypes.LittleEndianStructure,), { '_pack_': 1, @@ -459,16 +573,33 @@ def build_header(self): ) def parse(self, stream): - header_class = self.build_header() - header = stream.read_ctype(header_class) - stream.seek(ctypes.sizeof(header_class), SEEK_CUR) + header, header_class = self.__parse_header(stream) fields = [] for i in range(header.length): c_type = super().parse(stream) fields.append(('element_{}'.format(i), c_type)) - final_class = type( + return self.__build_final_class(header_class, fields) + + async def parse_async(self, stream): + header, header_class = self.__parse_header(stream) + + fields = [] + for i in range(header.length): + c_type = await super().parse_async(stream) + fields.append(('element_{}'.format(i), c_type)) + + return self.__build_final_class(header_class, fields) + + def __parse_header(self, stream): + header_class = self.build_header() + header = stream.read_ctype(header_class) + stream.seek(ctypes.sizeof(header_class), SEEK_CUR) + return header, header_class + + def __build_final_class(self, header_class, fields): + return type( self.__class__.__name__, (header_class,), { @@ -476,34 +607,58 @@ def parse(self, stream): '_fields_': fields, } ) - return final_class @classmethod def to_python(cls, ctype_object, *args, **kwargs): - result = [] - length = getattr(ctype_object, "length", None) - if length is None: - return None - for i in range(length): - result.append( + length = cls.__get_length(ctype_object) + + return [ + super().to_python(getattr(ctype_object, 'element_{}'.format(i)), *args, **kwargs) + for i in range(length) + ] + + @classmethod + async def to_python_async(cls, ctype_object, *args, **kwargs): + length = cls.__get_length(ctype_object) + + values = asyncio.gather( + *[ super().to_python( getattr(ctype_object, 'element_{}'.format(i)), *args, **kwargs - ) - ) - return result + ) for i in range(length) + ] + ) + return await values - def from_python(self, stream, value): - header_class = self.build_header() - header = header_class() + @staticmethod + def __get_length(ctype_object): + return getattr(ctype_object, "length", None) + def from_python(self, stream, value): try: length = len(value) except TypeError: value = [value] length = 1 - header.length = length + self.__write_header(stream, length) - stream.write(header) for x in value: infer_from_python(stream, x) + + async def from_python_async(self, stream, value): + try: + length = len(value) + except TypeError: + value = [value] + length = 1 + self.__write_header(stream, length) + + for x in value: + await infer_from_python_async(stream, x) + + def __write_header(self, stream, length): + header_class = self.build_header() + header = header_class() + header.length = length + stream.write(header) diff --git a/pyignite/datatypes/null_object.py b/pyignite/datatypes/null_object.py index 912ded8..f16034f 100644 --- a/pyignite/datatypes/null_object.py +++ b/pyignite/datatypes/null_object.py @@ -21,13 +21,12 @@ import ctypes from io import SEEK_CUR -from typing import Any from .base import IgniteDataType from .type_codes import TC_NULL -__all__ = ['Null'] +__all__ = ['Null', 'Nullable'] from ..constants import PROTOCOL_BYTE_ORDER @@ -37,11 +36,6 @@ class Null(IgniteDataType): pythonic = type(None) _object_c_type = None - @staticmethod - def hashcode(value: Any) -> int: - # Null object can not be a cache key. - return 0 - @classmethod def build_c_type(cls): if cls._object_c_type is None: @@ -59,55 +53,99 @@ def build_c_type(cls): @classmethod def parse(cls, stream): - init_pos, offset = stream.tell(), ctypes.sizeof(ctypes.c_byte) - stream.seek(offset, SEEK_CUR) + stream.seek(ctypes.sizeof(ctypes.c_byte), SEEK_CUR) return cls.build_c_type() - @staticmethod - def to_python(*args, **kwargs): + @classmethod + def to_python(cls, *args, **kwargs): return None - @staticmethod - def from_python(stream, *args): + @classmethod + def from_python(cls, stream, *args): stream.write(TC_NULL) -class Nullable: +class Nullable(IgniteDataType): @classmethod def parse_not_null(cls, stream): raise NotImplementedError + @classmethod + async def parse_not_null_async(cls, stream): + return cls.parse_not_null(stream) + @classmethod def parse(cls, stream): - type_len = ctypes.sizeof(ctypes.c_byte) + is_null, null_type = cls.__check_null_input(stream) - if stream.mem_view(offset=type_len) == TC_NULL: - stream.seek(type_len, SEEK_CUR) - return Null.build_c_type() + if is_null: + return null_type return cls.parse_not_null(stream) + @classmethod + async def parse_async(cls, stream): + is_null, null_type = cls.__check_null_input(stream) + + if is_null: + return null_type + + return await cls.parse_not_null_async(stream) + + @classmethod + def from_python_not_null(cls, stream, value, **kwargs): + raise NotImplementedError + + @classmethod + async def from_python_not_null_async(cls, stream, value, **kwargs): + return cls.from_python_not_null(stream, value, **kwargs) + + @classmethod + def from_python(cls, stream, value, **kwargs): + if value is None: + Null.from_python(stream) + else: + cls.from_python_not_null(stream, value) + + @classmethod + async def from_python_async(cls, stream, value, **kwargs): + if value is None: + Null.from_python(stream) + else: + await cls.from_python_not_null_async(stream, value, **kwargs) + @classmethod def to_python_not_null(cls, ctypes_object, *args, **kwargs): raise NotImplementedError + @classmethod + async def to_python_not_null_async(cls, ctypes_object, *args, **kwargs): + return cls.to_python_not_null(ctypes_object, *args, **kwargs) + @classmethod def to_python(cls, ctypes_object, *args, **kwargs): - if ctypes_object.type_code == int.from_bytes( - TC_NULL, - byteorder=PROTOCOL_BYTE_ORDER - ): + if cls.__is_null(ctypes_object): return None return cls.to_python_not_null(ctypes_object, *args, **kwargs) @classmethod - def from_python_not_null(cls, stream, value): - raise NotImplementedError + async def to_python_async(cls, ctypes_object, *args, **kwargs): + if cls.__is_null(ctypes_object): + return None + + return await cls.to_python_not_null_async(ctypes_object, *args, **kwargs) @classmethod - def from_python(cls, stream, value): - if value is None: - Null.from_python(stream) - else: - cls.from_python_not_null(stream, value) + def __check_null_input(cls, stream): + type_len = ctypes.sizeof(ctypes.c_byte) + + if stream.mem_view(offset=type_len) == TC_NULL: + stream.seek(type_len, SEEK_CUR) + return True, Null.build_c_type() + + return False, None + + @classmethod + def __is_null(cls, ctypes_object): + return ctypes_object.type_code == int.from_bytes(TC_NULL, byteorder=PROTOCOL_BYTE_ORDER) diff --git a/pyignite/datatypes/primitive.py b/pyignite/datatypes/primitive.py index ffa2e32..3bbb196 100644 --- a/pyignite/datatypes/primitive.py +++ b/pyignite/datatypes/primitive.py @@ -48,8 +48,7 @@ class Primitive(IgniteDataType): @classmethod def parse(cls, stream): - init_pos, offset = stream.tell(), ctypes.sizeof(cls.c_type) - stream.seek(offset, SEEK_CUR) + stream.seek(ctypes.sizeof(cls.c_type), SEEK_CUR) return cls.c_type @classmethod diff --git a/pyignite/datatypes/primitive_arrays.py b/pyignite/datatypes/primitive_arrays.py index 7cb5b20..a21de77 100644 --- a/pyignite/datatypes/primitive_arrays.py +++ b/pyignite/datatypes/primitive_arrays.py @@ -15,11 +15,8 @@ import ctypes from io import SEEK_CUR -from typing import Any from pyignite.constants import * -from . import Null -from .base import IgniteDataType from .null_object import Nullable from .primitive import * from .type_codes import * @@ -35,7 +32,7 @@ ] -class PrimitiveArray(IgniteDataType, Nullable): +class PrimitiveArray(Nullable): """ Base class for array of primitives. Payload-only. """ @@ -44,15 +41,10 @@ class PrimitiveArray(IgniteDataType, Nullable): primitive_type = None type_code = None - @staticmethod - def hashcode(value: Any) -> int: - # Arrays are not supported as keys at the moment. - return 0 - @classmethod def build_header_class(cls): return type( - cls.__name__+'Header', + cls.__name__ + 'Header', (ctypes.LittleEndianStructure,), { '_pack_': 1, @@ -88,7 +80,11 @@ def to_python(cls, ctype_object, *args, **kwargs): return [ctype_object.data[i] for i in range(ctype_object.length)] @classmethod - def from_python_not_null(cls, stream, value): + async def to_python_async(cls, ctypes_object, *args, **kwargs): + return cls.to_python(ctypes_object, *args, **kwargs) + + @classmethod + def from_python_not_null(cls, stream, value, **kwargs): header_class = cls.build_header_class() header = header_class() if hasattr(header, 'type_code'): @@ -188,7 +184,7 @@ class PrimitiveArrayObject(PrimitiveArray): @classmethod def build_header_class(cls): return type( - cls.__name__+'Header', + cls.__name__ + 'Header', (ctypes.LittleEndianStructure,), { '_pack_': 1, @@ -312,7 +308,5 @@ def to_python(cls, ctype_object, *args, **kwargs): length = getattr(ctype_object, "length", None) if length is None: return None - result = [False] * length - for i in range(length): - result[i] = ctype_object.data[i] != 0 - return result + + return [ctype_object.data[i] != 0 for i in range(length)] diff --git a/pyignite/datatypes/primitive_objects.py b/pyignite/datatypes/primitive_objects.py index e942dd7..5849935 100644 --- a/pyignite/datatypes/primitive_objects.py +++ b/pyignite/datatypes/primitive_objects.py @@ -18,11 +18,10 @@ from pyignite.constants import * from pyignite.utils import unsigned -from .base import IgniteDataType from .type_codes import * from .type_ids import * from .type_names import * -from .null_object import Null, Nullable +from .null_object import Nullable __all__ = [ 'DataObject', 'ByteObject', 'ShortObject', 'IntObject', 'LongObject', @@ -30,7 +29,7 @@ ] -class DataObject(IgniteDataType, Nullable): +class DataObject(Nullable): """ Base class for primitive data objects. @@ -65,12 +64,16 @@ def parse_not_null(cls, stream): stream.seek(ctypes.sizeof(data_type), SEEK_CUR) return data_type - @staticmethod - def to_python(ctype_object, *args, **kwargs): + @classmethod + def to_python(cls, ctype_object, *args, **kwargs): return getattr(ctype_object, "value", None) @classmethod - def from_python_not_null(cls, stream, value): + async def to_python_async(cls, ctype_object, *args, **kwargs): + return cls.to_python(ctype_object, *args, **kwargs) + + @classmethod + def from_python_not_null(cls, stream, value, **kwargs): data_type = cls.build_c_type() data_object = data_type() data_object.type_code = int.from_bytes( @@ -89,8 +92,8 @@ class ByteObject(DataObject): pythonic = int default = 0 - @staticmethod - def hashcode(value: int, *args, **kwargs) -> int: + @classmethod + def hashcode(cls, value: int, *args, **kwargs) -> int: return value @@ -102,8 +105,8 @@ class ShortObject(DataObject): pythonic = int default = 0 - @staticmethod - def hashcode(value: int, *args, **kwargs) -> int: + @classmethod + def hashcode(cls, value: int, *args, **kwargs) -> int: return value @@ -115,8 +118,8 @@ class IntObject(DataObject): pythonic = int default = 0 - @staticmethod - def hashcode(value: int, *args, **kwargs) -> int: + @classmethod + def hashcode(cls, value: int, *args, **kwargs) -> int: return value @@ -128,8 +131,8 @@ class LongObject(DataObject): pythonic = int default = 0 - @staticmethod - def hashcode(value: int, *args, **kwargs) -> int: + @classmethod + def hashcode(cls, value: int, *args, **kwargs) -> int: return value ^ (unsigned(value, ctypes.c_ulonglong) >> 32) @@ -141,8 +144,8 @@ class FloatObject(DataObject): pythonic = float default = 0.0 - @staticmethod - def hashcode(value: float, *args, **kwargs) -> int: + @classmethod + def hashcode(cls, value: float, *args, **kwargs) -> int: return ctypes.cast( ctypes.pointer(ctypes.c_float(value)), ctypes.POINTER(ctypes.c_int) @@ -157,8 +160,8 @@ class DoubleObject(DataObject): pythonic = float default = 0.0 - @staticmethod - def hashcode(value: float, *args, **kwargs) -> int: + @classmethod + def hashcode(cls, value: float, *args, **kwargs) -> int: bits = ctypes.cast( ctypes.pointer(ctypes.c_double(value)), ctypes.POINTER(ctypes.c_longlong) @@ -180,8 +183,8 @@ class CharObject(DataObject): pythonic = str default = ' ' - @staticmethod - def hashcode(value: str, *args, **kwargs) -> int: + @classmethod + def hashcode(cls, value: str, *args, **kwargs) -> int: return ord(value) @classmethod @@ -195,7 +198,7 @@ def to_python(cls, ctype_object, *args, **kwargs): ).decode(PROTOCOL_CHAR_ENCODING) @classmethod - def from_python_not_null(cls, stream, value): + def from_python_not_null(cls, stream, value, **kwargs): if type(value) is str: value = value.encode(PROTOCOL_CHAR_ENCODING) # assuming either a bytes or an integer @@ -216,8 +219,8 @@ class BoolObject(DataObject): pythonic = bool default = False - @staticmethod - def hashcode(value: bool, *args, **kwargs) -> int: + @classmethod + def hashcode(cls, value: bool, *args, **kwargs) -> int: return 1231 if value else 1237 @classmethod @@ -226,4 +229,3 @@ def to_python(cls, ctype_object, *args, **kwargs): if value is None: return None return value != 0 - diff --git a/pyignite/datatypes/standard.py b/pyignite/datatypes/standard.py index af50a8e..2b61235 100644 --- a/pyignite/datatypes/standard.py +++ b/pyignite/datatypes/standard.py @@ -18,16 +18,15 @@ import decimal from io import SEEK_CUR from math import ceil -from typing import Any, Tuple +from typing import Tuple import uuid from pyignite.constants import * from pyignite.utils import datetime_hashcode, decimal_hashcode, hashcode -from .base import IgniteDataType from .type_codes import * from .type_ids import * from .type_names import * -from .null_object import Null, Nullable +from .null_object import Nullable __all__ = [ 'String', 'DecimalObject', 'UUIDObject', 'TimestampObject', 'DateObject', @@ -44,7 +43,7 @@ ] -class StandardObject(IgniteDataType, Nullable): +class StandardObject(Nullable): _type_name = None _type_id = None type_code = None @@ -60,7 +59,7 @@ def parse_not_null(cls, stream): return data_type -class String(IgniteDataType, Nullable): +class String(Nullable): """ Pascal-style string: `c_int` counter, followed by count*bytes. UTF-8-encoded, so that one character may take 1 to 4 bytes. @@ -70,8 +69,8 @@ class String(IgniteDataType, Nullable): type_code = TC_STRING pythonic = str - @staticmethod - def hashcode(value: str, *args, **kwargs) -> int: + @classmethod + def hashcode(cls, value: str, *args, **kwargs) -> int: return hashcode(value) @classmethod @@ -124,15 +123,15 @@ def from_python_not_null(cls, stream, value): stream.write(data_object) -class DecimalObject(IgniteDataType, Nullable): +class DecimalObject(Nullable): _type_name = NAME_DECIMAL _type_id = TYPE_DECIMAL type_code = TC_DECIMAL pythonic = decimal.Decimal default = decimal.Decimal('0.00') - @staticmethod - def hashcode(value: decimal.Decimal, *args, **kwargs) -> int: + @classmethod + def hashcode(cls, value: decimal.Decimal, *args, **kwargs) -> int: return decimal_hashcode(value) @classmethod @@ -180,11 +179,7 @@ def to_python_not_null(cls, ctype_object, *args, **kwargs): range(len(data)) ]) # apply scale - result = ( - result - / decimal.Decimal('10') - ** decimal.Decimal(ctype_object.scale) - ) + result = result / decimal.Decimal('10') ** decimal.Decimal(ctype_object.scale) if sign: # apply sign result = -result @@ -195,7 +190,7 @@ def from_python_not_null(cls, stream, value: decimal.Decimal): sign, digits, scale = value.normalize().as_tuple() integer = int(''.join([str(d) for d in digits])) # calculate number of bytes (at least one, and not forget the sign bit) - length = ceil((integer.bit_length() + 1)/8) + length = ceil((integer.bit_length() + 1) / 8) # write byte string data = [] for i in range(length): @@ -247,8 +242,8 @@ class UUIDObject(StandardObject): UUID_BYTE_ORDER = (7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8) - @staticmethod - def hashcode(value: 'UUID', *args, **kwargs) -> int: + @classmethod + def hashcode(cls, value: 'UUID', *args, **kwargs) -> int: msb = value.int >> 64 lsb = value.int & 0xffffffffffffffff hilo = msb ^ lsb @@ -309,8 +304,8 @@ class TimestampObject(StandardObject): pythonic = tuple default = (datetime(1970, 1, 1), 0) - @staticmethod - def hashcode(value: Tuple[datetime, int], *args, **kwargs) -> int: + @classmethod + def hashcode(cls, value: Tuple[datetime, int], *args, **kwargs) -> int: return datetime_hashcode(int(value[0].timestamp() * 1000)) @classmethod @@ -331,7 +326,7 @@ def build_c_type(cls): return cls._object_c_type @classmethod - def from_python_not_null(cls, stream, value: tuple): + def from_python_not_null(cls, stream, value: tuple, **kwargs): data_type = cls.build_c_type() data_object = data_type() data_object.type_code = int.from_bytes( @@ -346,7 +341,7 @@ def from_python_not_null(cls, stream, value: tuple): @classmethod def to_python_not_null(cls, ctypes_object, *args, **kwargs): return ( - datetime.fromtimestamp(ctypes_object.epoch/1000), + datetime.fromtimestamp(ctypes_object.epoch / 1000), ctypes_object.fraction ) @@ -365,8 +360,8 @@ class DateObject(StandardObject): pythonic = datetime default = datetime(1970, 1, 1) - @staticmethod - def hashcode(value: datetime, *args, **kwargs) -> int: + @classmethod + def hashcode(cls, value: datetime, *args, **kwargs) -> int: return datetime_hashcode(int(value.timestamp() * 1000)) @classmethod @@ -401,7 +396,7 @@ def from_python_not_null(cls, stream, value: [date, datetime]): @classmethod def to_python_not_null(cls, ctypes_object, *args, **kwargs): - return datetime.fromtimestamp(ctypes_object.epoch/1000) + return datetime.fromtimestamp(ctypes_object.epoch / 1000) class TimeObject(StandardObject): @@ -417,8 +412,8 @@ class TimeObject(StandardObject): pythonic = timedelta default = timedelta() - @staticmethod - def hashcode(value: timedelta, *args, **kwargs) -> int: + @classmethod + def hashcode(cls, value: timedelta, *args, **kwargs) -> int: return datetime_hashcode(int(value.total_seconds() * 1000)) @classmethod @@ -510,7 +505,7 @@ class BinaryEnumObject(EnumObject): type_code = TC_BINARY_ENUM -class StandardArray(IgniteDataType, Nullable): +class StandardArray(Nullable): """ Base class for array of primitives. Payload-only. """ @@ -519,15 +514,10 @@ class StandardArray(IgniteDataType, Nullable): standard_type = None type_code = None - @staticmethod - def hashcode(value: Any) -> int: - # Arrays are not supported as keys at the moment. - return 0 - @classmethod def build_header_class(cls): return type( - cls.__name__+'Header', + cls.__name__ + 'Header', (ctypes.LittleEndianStructure,), { '_pack_': 1, @@ -575,7 +565,11 @@ def to_python(cls, ctype_object, *args, **kwargs): return result @classmethod - def from_python_not_null(cls, stream, value): + async def to_python_async(cls, ctypes_object, *args, **kwargs): + return cls.to_python(ctypes_object, *args, **kwargs) + + @classmethod + def from_python_not_null(cls, stream, value, **kwargs): header_class = cls.build_header_class() header = header_class() if hasattr(header, 'type_code'): @@ -648,7 +642,7 @@ class StandardArrayObject(StandardArray): @classmethod def build_header_class(cls): return type( - cls.__name__+'Header', + cls.__name__ + 'Header', (ctypes.LittleEndianStructure,), { '_pack_': 1, @@ -723,7 +717,7 @@ class EnumArrayObject(StandardArrayObject): @classmethod def build_header_class(cls): return type( - cls.__name__+'Header', + cls.__name__ + 'Header', (ctypes.LittleEndianStructure,), { '_pack_': 1, @@ -736,7 +730,7 @@ def build_header_class(cls): ) @classmethod - def from_python_not_null(cls, stream, value): + def from_python_not_null(cls, stream, value, **kwargs): type_id, value = value header_class = cls.build_header_class() header = header_class() @@ -754,7 +748,7 @@ def from_python_not_null(cls, stream, value): cls.standard_type.from_python(stream, x) @classmethod - def to_python(cls, ctype_object, *args, **kwargs): + def to_python_not_null(cls, ctype_object, *args, **kwargs): type_id = getattr(ctype_object, "type_id", None) if type_id is None: return None diff --git a/pyignite/exceptions.py b/pyignite/exceptions.py index 5933228..579aa29 100644 --- a/pyignite/exceptions.py +++ b/pyignite/exceptions.py @@ -93,4 +93,4 @@ class SQLError(CacheError): pass -connection_errors = (IOError, OSError) +connection_errors = (IOError, OSError, EOFError) diff --git a/pyignite/queries/__init__.py b/pyignite/queries/__init__.py index d558125..56c6347 100644 --- a/pyignite/queries/__init__.py +++ b/pyignite/queries/__init__.py @@ -21,4 +21,4 @@ :mod:`pyignite.datatypes` binary parser/generator classes. """ -from .query import Query, ConfigQuery +from .query import Query, ConfigQuery, query_perform diff --git a/pyignite/queries/query.py b/pyignite/queries/query.py index b5be753..beea5d9 100644 --- a/pyignite/queries/query.py +++ b/pyignite/queries/query.py @@ -13,15 +13,35 @@ # See the License for the specific language governing permissions and # limitations under the License. -import attr import ctypes +from io import SEEK_CUR from random import randint +import attr + from pyignite.api.result import APIResult -from pyignite.connection import Connection +from pyignite.connection import Connection, AioConnection from pyignite.constants import MIN_LONG, MAX_LONG, RHF_TOPOLOGY_CHANGED -from pyignite.queries.response import Response, SQLResponse -from pyignite.stream import BinaryStream, READ_BACKWARD +from pyignite.queries.response import Response +from pyignite.stream import AioBinaryStream, BinaryStream, READ_BACKWARD + + +def query_perform(query_struct, conn, post_process_fun=None, **kwargs): + async def _async_internal(): + result = await query_struct.perform_async(conn, **kwargs) + if post_process_fun: + return post_process_fun(result) + return result + + def _internal(): + result = query_struct.perform(conn, **kwargs) + if post_process_fun: + return post_process_fun(result) + return result + + if isinstance(conn, AioConnection): + return _async_internal() + return _internal() @attr.s @@ -29,6 +49,7 @@ class Query: op_code = attr.ib(type=int) following = attr.ib(type=list, factory=list) query_id = attr.ib(type=int, default=None) + response_type = attr.ib(type=type(Response), default=Response) _query_c_type = None @classmethod @@ -48,32 +69,45 @@ def build_c_type(cls): ) return cls._query_c_type - def _build_header(self, stream, values: dict): + def from_python(self, stream, values: dict = None): + init_pos, header = stream.tell(), self._build_header(stream) + values = values if values else None + + for name, c_type in self.following: + c_type.from_python(stream, values[name]) + + self.__write_header(stream, header, init_pos) + + async def from_python_async(self, stream, values: dict = None): + init_pos, header = stream.tell(), self._build_header(stream) + values = values if values else None + + for name, c_type in self.following: + await c_type.from_python_async(stream, values[name]) + + self.__write_header(stream, header, init_pos) + + def _build_header(self, stream): header_class = self.build_c_type() header_len = ctypes.sizeof(header_class) - init_pos = stream.tell() - stream.seek(init_pos + header_len) + stream.seek(header_len, SEEK_CUR) header = header_class() header.op_code = self.op_code if self.query_id is None: header.query_id = randint(MIN_LONG, MAX_LONG) - for name, c_type in self.following: - c_type.from_python(stream, values[name]) + return header + @staticmethod + def __write_header(stream, header, init_pos): header.length = stream.tell() - init_pos - ctypes.sizeof(ctypes.c_int) stream.seek(init_pos) - - return header - - def from_python(self, stream, values: dict = None): - header = self._build_header(stream, values if values else {}) stream.write(header) def perform( self, conn: Connection, query_params: dict = None, - response_config: list = None, sql: bool = False, **kwargs, + response_config: list = None, **kwargs, ) -> APIResult: """ Perform query and process result. @@ -83,26 +117,60 @@ def perform( Defaults to no parameters, :param response_config: (optional) response configuration − list of (name, type_hint) tuples. Defaults to empty return value, - :param sql: (optional) use normal (default) or SQL response class, :return: instance of :class:`~pyignite.api.result.APIResult` with raw value (may undergo further processing in API functions). """ - with BinaryStream(conn) as stream: + with BinaryStream(conn.client) as stream: self.from_python(stream, query_params) - conn.send(stream.getbuffer()) + response_data = conn.request(stream.getbuffer()) - if sql: - response_struct = SQLResponse(protocol_version=conn.get_protocol_version(), - following=response_config, **kwargs) - else: - response_struct = Response(protocol_version=conn.get_protocol_version(), - following=response_config) + response_struct = self.response_type(protocol_version=conn.protocol_version, + following=response_config, **kwargs) - with BinaryStream(conn, conn.recv()) as stream: + with BinaryStream(conn.client, response_data) as stream: response_ctype = response_struct.parse(stream) response = stream.read_ctype(response_ctype, direction=READ_BACKWARD) - # this test depends on protocol version + result = self.__post_process_response(conn, response_struct, response) + + if result.status == 0: + result.value = response_struct.to_python(response) + return result + + async def perform_async( + self, conn: AioConnection, query_params: dict = None, + response_config: list = None, **kwargs, + ) -> APIResult: + """ + Perform query and process result. + + :param conn: connection to Ignite server, + :param query_params: (optional) dict of named query parameters. + Defaults to no parameters, + :param response_config: (optional) response configuration − list of + (name, type_hint) tuples. Defaults to empty return value, + :return: instance of :class:`~pyignite.api.result.APIResult` with raw + value (may undergo further processing in API functions). + """ + with AioBinaryStream(conn.client) as stream: + await self.from_python_async(stream, query_params) + data = await conn.request(stream.getbuffer()) + + response_struct = self.response_type(protocol_version=conn.protocol_version, + following=response_config, **kwargs) + + with AioBinaryStream(conn.client, data) as stream: + response_ctype = await response_struct.parse_async(stream) + response = stream.read_ctype(response_ctype, direction=READ_BACKWARD) + + result = self.__post_process_response(conn, response_struct, response) + + if result.status == 0: + result.value = await response_struct.to_python_async(response) + return result + + @staticmethod + def __post_process_response(conn, response_struct, response): if getattr(response, 'flags', False) & RHF_TOPOLOGY_CHANGED: # update latest affinity version new_affinity = (response.affinity_version, response.affinity_minor) @@ -112,10 +180,7 @@ def perform( conn.client.affinity_version = new_affinity # build result - result = APIResult(response) - if result.status == 0: - result.value = response_struct.to_python(response) - return result + return APIResult(response) class ConfigQuery(Query): @@ -142,7 +207,7 @@ def build_c_type(cls): ) return cls._query_c_type - def _build_header(self, stream, values: dict): - header = super()._build_header(stream, values) + def _build_header(self, stream): + header = super()._build_header(stream) header.config_length = header.length - ctypes.sizeof(type(header)) return header diff --git a/pyignite/queries/response.py b/pyignite/queries/response.py index ca2ae14..83a6e6a 100644 --- a/pyignite/queries/response.py +++ b/pyignite/queries/response.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import asyncio from io import SEEK_CUR import attr @@ -20,6 +21,7 @@ from pyignite.constants import RHF_TOPOLOGY_CHANGED, RHF_ERROR from pyignite.datatypes import AnyDataObject, Bool, Int, Long, String, StringArray, Struct +from pyignite.datatypes.binary import body_struct, enum_struct, schema_struct from pyignite.queries.op_codes import OP_SUCCESS from pyignite.stream import READ_BACKWARD @@ -35,7 +37,7 @@ def __attrs_post_init__(self): # replace None with empty list self.following = self.following or [] - def build_header(self): + def __build_header(self): if self._response_header is None: fields = [ ('length', ctypes.c_int), @@ -57,9 +59,9 @@ def build_header(self): ) return self._response_header - def parse(self, stream): + def __parse_header(self, stream): init_pos = stream.tell() - header_class = self.build_header() + header_class = self.__build_header() header_len = ctypes.sizeof(header_class) header = stream.read_ctype(header_class) stream.seek(header_len, SEEK_CUR) @@ -85,9 +87,10 @@ def parse(self, stream): if has_error: msg_type = String.parse(stream) fields.append(('error_message', msg_type)) - else: - self._parse_success(stream, fields) + return not has_error, init_pos, header_class, fields + + def __build_response_class(self, stream, init_pos, header_class, fields): response_class = type( self._response_class_name, (header_class,), @@ -100,21 +103,52 @@ def parse(self, stream): stream.seek(init_pos + ctypes.sizeof(response_class)) return response_class + def parse(self, stream): + success, init_pos, header_class, fields = self.__parse_header(stream) + if success: + self._parse_success(stream, fields) + + return self.__build_response_class(stream, init_pos, header_class, fields) + + async def parse_async(self, stream): + success, init_pos, header_class, fields = self.__parse_header(stream) + if success: + await self._parse_success_async(stream, fields) + + return self.__build_response_class(stream, init_pos, header_class, fields) + def _parse_success(self, stream, fields: list): for name, ignite_type in self.following: c_type = ignite_type.parse(stream) fields.append((name, c_type)) + async def _parse_success_async(self, stream, fields: list): + for name, ignite_type in self.following: + c_type = await ignite_type.parse_async(stream) + fields.append((name, c_type)) + def to_python(self, ctype_object, *args, **kwargs): - result = OrderedDict() + if not self.following: + return None + result = OrderedDict() for name, c_type in self.following: result[name] = c_type.to_python( getattr(ctype_object, name), *args, **kwargs ) - return result if result else None + return result + + async def to_python_async(self, ctype_object, *args, **kwargs): + if not self.following: + return None + + values = await asyncio.gather( + *[c_type.to_python_async(getattr(ctype_object, name), *args, **kwargs) for name, c_type in self.following] + ) + + return OrderedDict([(name, values[i]) for i, (name, _) in enumerate(self.following)]) @attr.s @@ -135,38 +169,62 @@ def fields_or_field_count(self): return 'field_count', Int def _parse_success(self, stream, fields: list): - following = [ - self.fields_or_field_count(), - ('row_count', Int), - ] - if self.has_cursor: - following.insert(0, ('cursor', Long)) - body_struct = Struct(following) + body_struct = self.__create_body_struct() body_class = body_struct.parse(stream) body = stream.read_ctype(body_class, direction=READ_BACKWARD) - if self.include_field_names: - field_count = body.fields.length - else: - field_count = body.field_count - - data_fields = [] + data_fields, field_count = [], self.__get_fields_count(body) for i in range(body.row_count): row_fields = [] for j in range(field_count): field_class = AnyDataObject.parse(stream) row_fields.append(('column_{}'.format(j), field_class)) - row_class = type( - 'SQLResponseRow', - (ctypes.LittleEndianStructure,), - { - '_pack_': 1, - '_fields_': row_fields, - } - ) - data_fields.append(('row_{}'.format(i), row_class)) + self.__row_post_process(i, row_fields, data_fields) + + self.__body_class_post_process(body_class, fields, data_fields) + + async def _parse_success_async(self, stream, fields: list): + body_struct = self.__create_body_struct() + body_class = await body_struct.parse_async(stream) + body = stream.read_ctype(body_class, direction=READ_BACKWARD) + + data_fields, field_count = [], self.__get_fields_count(body) + for i in range(body.row_count): + row_fields = [] + for j in range(field_count): + field_class = await AnyDataObject.parse_async(stream) + row_fields.append(('column_{}'.format(j), field_class)) + + self.__row_post_process(i, row_fields, data_fields) + + self.__body_class_post_process(body_class, fields, data_fields) + + def __create_body_struct(self): + following = [self.fields_or_field_count(), ('row_count', Int)] + if self.has_cursor: + following.insert(0, ('cursor', Long)) + return Struct(following) + + def __get_fields_count(self, body): + if self.include_field_names: + return body.fields.length + return body.field_count + + @staticmethod + def __row_post_process(idx, row_fields, data_fields): + row_class = type( + 'SQLResponseRow', + (ctypes.LittleEndianStructure,), + { + '_pack_': 1, + '_fields_': row_fields, + } + ) + data_fields.append((f'row_{idx}', row_class)) + @staticmethod + def __body_class_post_process(body_class, fields, data_fields): data_class = type( 'SQLResponseData', (ctypes.LittleEndianStructure,), @@ -182,24 +240,8 @@ def _parse_success(self, stream, fields: list): def to_python(self, ctype_object, *args, **kwargs): if getattr(ctype_object, 'status_code', 0) == 0: - result = { - 'more': Bool.to_python( - ctype_object.more, *args, **kwargs - ), - 'data': [], - } - if hasattr(ctype_object, 'fields'): - result['fields'] = StringArray.to_python( - ctype_object.fields, *args, **kwargs - ) - else: - result['field_count'] = Int.to_python( - ctype_object.field_count, *args, **kwargs - ) - if hasattr(ctype_object, 'cursor'): - result['cursor'] = Long.to_python( - ctype_object.cursor, *args, **kwargs - ) + result = self.__to_python_result_header(ctype_object, *args, **kwargs) + for row_item in ctype_object.data._fields_: row_name = row_item[0] row_object = getattr(ctype_object.data, row_name) @@ -207,8 +249,104 @@ def to_python(self, ctype_object, *args, **kwargs): for col_item in row_object._fields_: col_name = col_item[0] col_object = getattr(row_object, col_name) - row.append( - AnyDataObject.to_python(col_object, *args, **kwargs) - ) + row.append(AnyDataObject.to_python(col_object, *args, **kwargs)) result['data'].append(row) return result + + async def to_python_async(self, ctype_object, *args, **kwargs): + if getattr(ctype_object, 'status_code', 0) == 0: + result = self.__to_python_result_header(ctype_object, *args, **kwargs) + + data_coro = [] + for row_item in ctype_object.data._fields_: + row_name = row_item[0] + row_object = getattr(ctype_object.data, row_name) + row_coro = [] + for col_item in row_object._fields_: + col_name = col_item[0] + col_object = getattr(row_object, col_name) + row_coro.append(AnyDataObject.to_python_async(col_object, *args, **kwargs)) + + data_coro.append(asyncio.gather(*row_coro)) + + result['data'] = await asyncio.gather(*data_coro) + return result + + @staticmethod + def __to_python_result_header(ctype_object, *args, **kwargs): + result = { + 'more': Bool.to_python(ctype_object.more, *args, **kwargs), + 'data': [], + } + if hasattr(ctype_object, 'fields'): + result['fields'] = StringArray.to_python(ctype_object.fields, *args, **kwargs) + else: + result['field_count'] = Int.to_python(ctype_object.field_count, *args, **kwargs) + + if hasattr(ctype_object, 'cursor'): + result['cursor'] = Long.to_python(ctype_object.cursor, *args, **kwargs) + return result + + +class BinaryTypeResponse(Response): + _response_class_name = 'GetBinaryTypeResponse' + + def _parse_success(self, stream, fields: list): + type_exists = self.__process_type_exists(stream, fields) + + if type_exists.value: + resp_body_type = body_struct.parse(stream) + fields.append(('body', resp_body_type)) + resp_body = stream.read_ctype(resp_body_type, direction=READ_BACKWARD) + if resp_body.is_enum: + resp_enum = enum_struct.parse(stream) + fields.append(('enums', resp_enum)) + + resp_schema_type = schema_struct.parse(stream) + fields.append(('schema', resp_schema_type)) + + async def _parse_success_async(self, stream, fields: list): + type_exists = self.__process_type_exists(stream, fields) + + if type_exists.value: + resp_body_type = await body_struct.parse_async(stream) + fields.append(('body', resp_body_type)) + resp_body = stream.read_ctype(resp_body_type, direction=READ_BACKWARD) + if resp_body.is_enum: + resp_enum = await enum_struct.parse_async(stream) + fields.append(('enums', resp_enum)) + + resp_schema_type = await schema_struct.parse_async(stream) + fields.append(('schema', resp_schema_type)) + + @staticmethod + def __process_type_exists(stream, fields): + fields.append(('type_exists', ctypes.c_byte)) + type_exists = stream.read_ctype(ctypes.c_byte) + stream.seek(ctypes.sizeof(ctypes.c_byte), SEEK_CUR) + + return type_exists + + def to_python(self, ctype_object, *args, **kwargs): + if getattr(ctype_object, 'status_code', 0) == 0: + result = { + 'type_exists': Bool.to_python(ctype_object.type_exists) + } + + if hasattr(ctype_object, 'body'): + result.update(body_struct.to_python(ctype_object.body)) + + if hasattr(ctype_object, 'enums'): + result['enums'] = enum_struct.to_python(ctype_object.enums) + + if hasattr(ctype_object, 'schema'): + result['schema'] = { + x['schema_id']: [ + z['schema_field_id'] for z in x['schema_fields'] + ] + for x in schema_struct.to_python(ctype_object.schema) + } + return result + + async def to_python_async(self, ctype_object, *args, **kwargs): + return self.to_python(ctype_object, *args, **kwargs) diff --git a/pyignite/stream/__init__.py b/pyignite/stream/__init__.py index 94153b4..76d171d 100644 --- a/pyignite/stream/__init__.py +++ b/pyignite/stream/__init__.py @@ -13,4 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .binary_stream import BinaryStream, READ_FORWARD, READ_BACKWARD \ No newline at end of file +from .binary_stream import BinaryStream, AioBinaryStream, READ_FORWARD, READ_BACKWARD + +__all__ = ['BinaryStream', 'AioBinaryStream', 'READ_BACKWARD', 'READ_FORWARD'] diff --git a/pyignite/stream/binary_stream.py b/pyignite/stream/binary_stream.py index 46ac683..57b4b83 100644 --- a/pyignite/stream/binary_stream.py +++ b/pyignite/stream/binary_stream.py @@ -14,39 +14,23 @@ # limitations under the License. import ctypes from io import BytesIO +from typing import Union, Optional +import pyignite import pyignite.utils as ignite_utils READ_FORWARD = 0 READ_BACKWARD = 1 -class BinaryStream: - def __init__(self, conn, buf=None): - """ - Initialize binary stream around buffers. - - :param buf: Buffer, optional parameter. If not passed, creates empty BytesIO. - :param conn: Connection instance, required. - """ - from pyignite.connection import Connection - - if not isinstance(conn, Connection): - raise TypeError(f"invalid parameter: expected instance of {Connection}") - - if buf and not isinstance(buf, (bytearray, bytes, memoryview)): - raise TypeError(f"invalid parameter: expected bytes-like object") - - self.conn = conn - self.stream = BytesIO(buf) if buf else BytesIO() - +class BinaryStreamBaseMixin: @property def compact_footer(self) -> bool: - return self.conn.client.compact_footer + return self.client.compact_footer @compact_footer.setter def compact_footer(self, value: bool): - self.conn.client.compact_footer = value + self.client.compact_footer = value def read(self, size): buf = bytearray(size) @@ -86,10 +70,10 @@ def getbuffer(self): def mem_view(self, start=-1, offset=0): start = start if start >= 0 else self.tell() - return self.stream.getbuffer()[start:start+offset] + return self.stream.getbuffer()[start:start + offset] def hashcode(self, start, bytes_len): - return ignite_utils.hashcode(self.stream.getbuffer()[start:start+bytes_len]) + return ignite_utils.hashcode(self.stream.getbuffer()[start:start + bytes_len]) def __enter__(self): return self @@ -100,15 +84,48 @@ def __exit__(self, exc_type, exc_value, traceback): except BufferError: pass + +class BinaryStream(BinaryStreamBaseMixin): + """ + Synchronous binary stream. + """ + def __init__(self, client: 'pyignite.Client', buf: Optional[Union[bytes, bytearray, memoryview]] = None): + """ + :param client: Client instance, required. + :param buf: Buffer, optional parameter. If not passed, creates empty BytesIO. + """ + self.client = client + self.stream = BytesIO(buf) if buf else BytesIO() + def get_dataclass(self, header): - # get field names from outer space - result = self.conn.client.query_binary_type( - header.type_id, - header.schema_id - ) + result = self.client.query_binary_type(header.type_id, header.schema_id) if not result: raise RuntimeError('Binary type is not registered') return result def register_binary_type(self, *args, **kwargs): - return self.conn.client.register_binary_type(*args, **kwargs) + self.client.register_binary_type(*args, **kwargs) + + +class AioBinaryStream(BinaryStreamBaseMixin): + """ + Asyncio binary stream. + """ + def __init__(self, client: 'pyignite.AioClient', buf: Optional[Union[bytes, bytearray, memoryview]] = None): + """ + Initialize binary stream around buffers. + + :param client: AioClient instance, required. + :param buf: Buffer, optional parameter. If not passed, creates empty BytesIO. + """ + self.client = client + self.stream = BytesIO(buf) if buf else BytesIO() + + async def get_dataclass(self, header): + result = await self.client.query_binary_type(header.type_id, header.schema_id) + if not result: + raise RuntimeError('Binary type is not registered') + return result + + async def register_binary_type(self, *args, **kwargs): + await self.client.register_binary_type(*args, **kwargs) diff --git a/pyignite/utils.py b/pyignite/utils.py index f1a7f90..975f414 100644 --- a/pyignite/utils.py +++ b/pyignite/utils.py @@ -15,6 +15,7 @@ import ctypes import decimal +import inspect import warnings from functools import wraps @@ -65,23 +66,14 @@ def is_hinted(value): """ Check if a value is a tuple of data item and its type hint. """ - return ( - isinstance(value, tuple) - and len(value) == 2 - and issubclass(value[1], IgniteDataType) - ) + return isinstance(value, tuple) and len(value) == 2 and issubclass(value[1], IgniteDataType) def is_wrapped(value: Any) -> bool: """ Check if a value is of WrappedDataObject type. """ - return ( - type(value) is tuple - and len(value) == 2 - and type(value[0]) is bytes - and type(value[1]) is int - ) + return type(value) is tuple and len(value) == 2 and type(value[0]) is bytes and type(value[1]) is int def int_overflow(value: int) -> int: @@ -107,7 +99,7 @@ def hashcode(data: Union[str, bytes, bytearray, memoryview]) -> int: def __hashcode_fallback(data: Union[str, bytes, bytearray, memoryview]) -> int: if data is None: return 0 - + if isinstance(data, str): """ For strings we iterate over code point which are of the int type @@ -206,8 +198,7 @@ def decimal_hashcode(value: decimal.Decimal) -> int: # this is the case when Java BigDecimal digits are stored # compactly, in the internal 64-bit integer field int_hash = ( - (unsigned(value, ctypes.c_ulonglong) >> 32) * 31 - + (value & LONG_MASK) + (unsigned(value, ctypes.c_ulonglong) >> 32) * 31 + (value & LONG_MASK) ) & LONG_MASK else: # digits are not fit in the 64-bit long, so they get split internally @@ -243,25 +234,31 @@ def datetime_hashcode(value: int) -> int: def status_to_exception(exc: Type[Exception]): """ Converts erroneous status code with error message to an exception - of the given class. + of the given class. Supports coroutines. :param exc: the class of exception to raise, - :return: decorator. + :return: decorated function. """ + def process_result(result): + if result.status != 0: + raise exc(result.message) + return result.value + def ste_decorator(fn): - @wraps(fn) - def ste_wrapper(*args, **kwargs): - result = fn(*args, **kwargs) - if result.status != 0: - raise exc(result.message) - return result.value - return ste_wrapper + if inspect.iscoroutinefunction(fn): + @wraps(fn) + async def ste_wrapper_async(*args, **kwargs): + return process_result(await fn(*args, **kwargs)) + return ste_wrapper_async + else: + @wraps(fn) + def ste_wrapper(*args, **kwargs): + return process_result(fn(*args, **kwargs)) + return ste_wrapper return ste_decorator -def get_field_by_id( - obj: 'GenericObjectMeta', field_id: int -) -> Tuple[Any, IgniteDataType]: +def get_field_by_id(obj: 'GenericObjectMeta', field_id: int) -> Tuple[Any, IgniteDataType]: """ Returns a complex object's field value, given the field's entity ID. diff --git a/requirements/tests.txt b/requirements/tests.txt index 5d5ae84..38a8e9e 100644 --- a/requirements/tests.txt +++ b/requirements/tests.txt @@ -1,7 +1,10 @@ # these packages are used for testing +async_generator==1.10; python_version < '3.7' pytest==6.2.2 pytest-cov==2.11.1 +pytest-asyncio==0.14.0 teamcity-messages==1.28 psutil==5.8.0 jinja2==2.11.3 +flake8==3.8.4 diff --git a/setup.py b/setup.py index 4d90e4e..5db3aed 100644 --- a/setup.py +++ b/setup.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import re from collections import defaultdict from distutils.command.build_ext import build_ext from distutils.errors import CCompilerError, DistutilsExecError, DistutilsPlatformError @@ -86,6 +86,14 @@ def is_a_requirement(line): with open('README.md', 'r', encoding='utf-8') as readme_file: long_description = readme_file.read() +version = '' +with open('pyignite/__init__.py', 'r') as fd: + version = re.search(r'^__version__\s*=\s*[\'"]([^\'"]*)[\'"]', + fd.read(), re.MULTILINE).group(1) + +if not version: + raise RuntimeError('Cannot find version information') + def run_setup(with_binary=True): if with_binary: @@ -98,7 +106,7 @@ def run_setup(with_binary=True): setuptools.setup( name='pyignite', - version='0.4.0', + version=version, python_requires='>=3.6', author='The Apache Software Foundation', author_email='dev@ignite.apache.org', diff --git a/tests/affinity/conftest.py b/tests/affinity/conftest.py index 7595f25..2ec2b1b 100644 --- a/tests/affinity/conftest.py +++ b/tests/affinity/conftest.py @@ -15,8 +15,7 @@ import pytest -from pyignite import Client -from pyignite.api import cache_create, cache_destroy +from pyignite import Client, AioClient from tests.util import start_ignite_gen # Sometimes on slow testing servers and unstable topology @@ -42,29 +41,21 @@ def server3(): @pytest.fixture def client(): client = Client(partition_aware=True, timeout=CLIENT_SOCKET_TIMEOUT) - - client.connect([('127.0.0.1', 10800 + i) for i in range(1, 4)]) - - yield client - - client.close() - - -@pytest.fixture -def client_not_connected(): - client = Client(partition_aware=True, timeout=CLIENT_SOCKET_TIMEOUT) - yield client - client.close() + try: + client.connect([('127.0.0.1', 10800 + i) for i in range(1, 4)]) + yield client + finally: + client.close() @pytest.fixture -def cache(connected_client): - cache_name = 'my_bucket' - conn = connected_client.random_node - - cache_create(conn, cache_name) - yield cache_name - cache_destroy(conn, cache_name) +async def async_client(): + client = AioClient(partition_aware=True) + try: + await client.connect([('127.0.0.1', 10800 + i) for i in range(1, 4)]) + yield client + finally: + await client.close() @pytest.fixture(scope='module', autouse=True) diff --git a/tests/affinity/test_affinity.py b/tests/affinity/test_affinity.py index ee8f6c0..b1bcec7 100644 --- a/tests/affinity/test_affinity.py +++ b/tests/affinity/test_affinity.py @@ -13,178 +13,265 @@ # See the License for the specific language governing permissions and # limitations under the License. -from datetime import datetime, timedelta +import asyncio import decimal +from datetime import datetime, timedelta from uuid import UUID, uuid4 import pytest -from pyignite import GenericObjectMeta -from pyignite.api import * -from pyignite.constants import * -from pyignite.datatypes import * +from pyignite import GenericObjectMeta, AioClient +from pyignite.api import ( + cache_get_node_partitions, cache_get_node_partitions_async, cache_local_peek, cache_local_peek_async +) +from pyignite.constants import MAX_INT +from pyignite.datatypes import ( + BinaryObject, ByteArray, ByteObject, IntObject, ShortObject, LongObject, FloatObject, DoubleObject, BoolObject, + CharObject, String, UUIDObject, DecimalObject, TimestampObject, TimeObject +) from pyignite.datatypes.cache_config import CacheMode -from pyignite.datatypes.prop_codes import * +from pyignite.datatypes.prop_codes import PROP_NAME, PROP_CACHE_MODE, PROP_CACHE_KEY_CONFIGURATION +from tests.util import wait_for_condition, wait_for_condition_async -def test_get_node_partitions(client): - conn = client.random_node +def test_get_node_partitions(client, caches): + cache_ids = [cache.cache_id for cache in caches] + __wait_for_ready_affinity(client, cache_ids) + mappings = __get_mappings(client, cache_ids) + __check_mappings(mappings, cache_ids) - cache_1 = client.get_or_create_cache('test_cache_1') - cache_2 = client.get_or_create_cache({ - PROP_NAME: 'test_cache_2', - PROP_CACHE_KEY_CONFIGURATION: [ - { - 'type_name': ByteArray.type_name, - 'affinity_key_field_name': 'byte_affinity', - } - ], - }) - client.get_or_create_cache('test_cache_3') - client.get_or_create_cache('test_cache_4') - client.get_or_create_cache('test_cache_5') - - result = cache_get_node_partitions( - conn, - [cache_1.cache_id, cache_2.cache_id] - ) - assert result.status == 0, result.message - - -@pytest.mark.parametrize( - 'key, key_hint', [ - # integers - (42, None), - (43, ByteObject), - (-44, ByteObject), - (45, IntObject), - (-46, IntObject), - (47, ShortObject), - (-48, ShortObject), - (49, LongObject), - (MAX_INT-50, LongObject), - (MAX_INT+51, LongObject), - - # floating point - (5.2, None), - (5.354, FloatObject), - (-5.556, FloatObject), - (-57.58, DoubleObject), - - # boolean - (True, None), - (True, BoolObject), - (False, BoolObject), - - # char - ('A', CharObject), - ('Z', CharObject), - ('⅓', CharObject), - ('á', CharObject), - ('ы', CharObject), - ('カ', CharObject), - ('Ø', CharObject), - ('ß', CharObject), - - # string - ('This is a test string', None), - ('Кириллица', None), - ('Little Mary had a lamb', String), - - # UUID - (UUID('12345678123456789876543298765432'), None), - (UUID('74274274274274274274274274274274'), UUIDObject), - (uuid4(), None), - - # decimal (long internal representation in Java) - (decimal.Decimal('-234.567'), None), - (decimal.Decimal('200.0'), None), - (decimal.Decimal('123.456'), DecimalObject), - (decimal.Decimal('1.0'), None), - (decimal.Decimal('0.02'), None), - - # decimal (BigInteger internal representation in Java) - (decimal.Decimal('12345671234567123.45671234567'), None), - (decimal.Decimal('-845678456.7845678456784567845'), None), - - # date and time - (datetime(1980, 1, 1), None), - ((datetime(1980, 1, 1), 999), TimestampObject), - (timedelta(days=99), TimeObject), - - ], -) -def test_affinity(client, key, key_hint): - cache_1 = client.get_or_create_cache({ - PROP_NAME: 'test_cache_1', - PROP_CACHE_MODE: CacheMode.PARTITIONED, - }) - value = 42 - cache_1.put(key, value, key_hint=key_hint) - best_node = cache_1.get_best_node(key, key_hint=key_hint) +@pytest.mark.asyncio +async def test_get_node_partitions_async(async_client, async_caches): + cache_ids = [cache.cache_id for cache in async_caches] + await __wait_for_ready_affinity(async_client, cache_ids) + mappings = await __get_mappings(async_client, cache_ids) + __check_mappings(mappings, cache_ids) - for node in filter(lambda n: n.alive, client._nodes): - result = cache_local_peek( - node, cache_1.cache_id, key, key_hint=key_hint, - ) - if node is best_node: - assert result.value == value, ( - 'Affinity calculation error for {}'.format(key) - ) - else: - assert result.value is None, ( - 'Affinity calculation error for {}'.format(key) - ) - cache_1.destroy() +def __wait_for_ready_affinity(client, cache_ids): + def inner(): + def condition(): + result = __get_mappings(client, cache_ids) + return len(result.value['partition_mapping']) == len(cache_ids) + wait_for_condition(condition) -def test_affinity_for_generic_object(client): - cache_1 = client.get_or_create_cache({ - PROP_NAME: 'test_cache_1', - PROP_CACHE_MODE: CacheMode.PARTITIONED, - }) + async def inner_async(): + async def condition(): + result = await __get_mappings(client, cache_ids) + return len(result.value['partition_mapping']) == len(cache_ids) - class KeyClass( - metaclass=GenericObjectMeta, - schema={ - 'NO': IntObject, - 'NAME': String, - }, - ): - pass + await wait_for_condition_async(condition) - key = KeyClass() - key.NO = 1 - key.NAME = 'test_string' + return inner_async() if isinstance(client, AioClient) else inner() - cache_1.put(key, 42, key_hint=BinaryObject) - best_node = cache_1.get_best_node(key, key_hint=BinaryObject) +def __get_mappings(client, cache_ids): + def inner(): + conn = client.random_node + result = cache_get_node_partitions(conn, cache_ids) + assert result.status == 0, result.message + return result + + async def inner_async(): + conn = await client.random_node() + result = await cache_get_node_partitions_async(conn, cache_ids) + assert result.status == 0, result.message + return result + + return inner_async() if isinstance(client, AioClient) else inner() - for node in filter(lambda n: n.alive, client._nodes): - result = cache_local_peek( - node, cache_1.cache_id, key, key_hint=BinaryObject, - ) - if node is best_node: - assert result.value == 42, ( - 'Affinity calculation error for {}'.format(key) - ) - else: - assert result.value is None, ( - 'Affinity calculation error for {}'.format(key) - ) - cache_1.destroy() +def __check_mappings(result, cache_ids): + partition_mapping = result.value['partition_mapping'] + for i, cache_id in enumerate(cache_ids): + cache_mapping = partition_mapping[cache_id] + assert 'is_applicable' in cache_mapping -def test_affinity_for_generic_object_without_type_hints(client): - cache_1 = client.get_or_create_cache({ + # Check replicated cache + if i == 3: + assert not cache_mapping['is_applicable'] + assert 'node_mapping' not in cache_mapping + assert cache_mapping['number_of_partitions'] == 0 + else: + # Check cache config + if i == 2: + assert cache_mapping['cache_config'] + + assert cache_mapping['is_applicable'] + assert cache_mapping['node_mapping'] + assert cache_mapping['number_of_partitions'] == 1024 + + +@pytest.fixture +def caches(client): + yield from __create_caches_fixture(client) + + +@pytest.fixture +async def async_caches(async_client): + async for caches in __create_caches_fixture(async_client): + yield caches + + +def __create_caches_fixture(client): + caches_to_create = [] + for i in range(0, 5): + cache_name = f'test_cache_{i}' + if i == 2: + caches_to_create.append(( + cache_name, + { + PROP_NAME: cache_name, + PROP_CACHE_KEY_CONFIGURATION: [ + { + 'type_name': ByteArray.type_name, + 'affinity_key_field_name': 'byte_affinity', + } + ] + })) + elif i == 3: + caches_to_create.append(( + cache_name, + { + PROP_NAME: cache_name, + PROP_CACHE_MODE: CacheMode.REPLICATED + } + )) + else: + caches_to_create.append((cache_name, None)) + + def generate_caches(): + caches = [] + for name, config in caches_to_create: + if config: + cache = client.get_or_create_cache(config) + else: + cache = client.get_or_create_cache(name) + caches.append(cache) + return asyncio.gather(*caches) if isinstance(client, AioClient) else caches + + def inner(): + caches = [] + try: + caches = generate_caches() + yield caches + finally: + for cache in caches: + cache.destroy() + + async def inner_async(): + caches = [] + try: + caches = await generate_caches() + yield caches + finally: + await asyncio.gather(*[cache.destroy() for cache in caches]) + + return inner_async() if isinstance(client, AioClient) else inner() + + +@pytest.fixture +def cache(client): + cache = client.get_or_create_cache({ PROP_NAME: 'test_cache_1', PROP_CACHE_MODE: CacheMode.PARTITIONED, }) + try: + yield cache + finally: + cache.destroy() + +@pytest.fixture +async def async_cache(async_client): + cache = await async_client.get_or_create_cache({ + PROP_NAME: 'test_cache_1', + PROP_CACHE_MODE: CacheMode.PARTITIONED, + }) + try: + yield cache + finally: + await cache.destroy() + + +affinity_primitives_params = [ + # integers + (42, None), + (43, ByteObject), + (-44, ByteObject), + (45, IntObject), + (-46, IntObject), + (47, ShortObject), + (-48, ShortObject), + (49, LongObject), + (MAX_INT - 50, LongObject), + (MAX_INT + 51, LongObject), + + # floating point + (5.2, None), + (5.354, FloatObject), + (-5.556, FloatObject), + (-57.58, DoubleObject), + + # boolean + (True, None), + (True, BoolObject), + (False, BoolObject), + + # char + ('A', CharObject), + ('Z', CharObject), + ('⅓', CharObject), + ('á', CharObject), + ('ы', CharObject), + ('カ', CharObject), + ('Ø', CharObject), + ('ß', CharObject), + + # string + ('This is a test string', None), + ('Кириллица', None), + ('Little Mary had a lamb', String), + + # UUID + (UUID('12345678123456789876543298765432'), None), + (UUID('74274274274274274274274274274274'), UUIDObject), + (uuid4(), None), + + # decimal (long internal representation in Java) + (decimal.Decimal('-234.567'), None), + (decimal.Decimal('200.0'), None), + (decimal.Decimal('123.456'), DecimalObject), + (decimal.Decimal('1.0'), None), + (decimal.Decimal('0.02'), None), + + # decimal (BigInteger internal representation in Java) + (decimal.Decimal('12345671234567123.45671234567'), None), + (decimal.Decimal('-845678456.7845678456784567845'), None), + + # date and time + (datetime(1980, 1, 1), None), + ((datetime(1980, 1, 1), 999), TimestampObject), + (timedelta(days=99), TimeObject) +] + + +@pytest.mark.parametrize('key, key_hint', affinity_primitives_params) +def test_affinity(client, cache, key, key_hint): + __check_best_node_calculation(client, cache, key, 42, key_hint=key_hint) + + +@pytest.mark.parametrize('key, key_hint', affinity_primitives_params) +@pytest.mark.asyncio +async def test_affinity_async(async_client, async_cache, key, key_hint): + await __check_best_node_calculation(async_client, async_cache, key, 42, key_hint=key_hint) + + +@pytest.fixture +def key_generic_object(): class KeyClass( metaclass=GenericObjectMeta, schema={ @@ -195,24 +282,47 @@ class KeyClass( pass key = KeyClass() - key.NO = 2 - key.NAME = 'another_test_string' + key.NO = 1 + key.NAME = 'test_string' + yield key + - cache_1.put(key, 42) +@pytest.mark.parametrize('with_type_hint', [True, False]) +def test_affinity_for_generic_object(client, cache, key_generic_object, with_type_hint): + key_hint = BinaryObject if with_type_hint else None + __check_best_node_calculation(client, cache, key_generic_object, 42, key_hint=key_hint) - best_node = cache_1.get_best_node(key) - for node in filter(lambda n: n.alive, client._nodes): - result = cache_local_peek( - node, cache_1.cache_id, key - ) +@pytest.mark.parametrize('with_type_hint', [True, False]) +@pytest.mark.asyncio +async def test_affinity_for_generic_object_async(async_client, async_cache, key_generic_object, with_type_hint): + key_hint = BinaryObject if with_type_hint else None + await __check_best_node_calculation(async_client, async_cache, key_generic_object, 42, key_hint=key_hint) + + +def __check_best_node_calculation(client, cache, key, value, key_hint=None): + def check_peek_value(node, best_node, result): if node is best_node: - assert result.value == 42, ( - 'Affinity calculation error for {}'.format(key) - ) + assert result.value == value, f'Affinity calculation error for {key}' else: - assert result.value is None, ( - 'Affinity calculation error for {}'.format(key) - ) + assert result.value is None, f'Affinity calculation error for {key}' + + def inner(): + cache.put(key, value, key_hint=key_hint) + best_node = cache.get_best_node(key, key_hint=key_hint) + + for node in filter(lambda n: n.alive, client._nodes): + result = cache_local_peek(node, cache.cache_id, key, key_hint=key_hint) + + check_peek_value(node, best_node, result) + + async def inner_async(): + await cache.put(key, value, key_hint=key_hint) + best_node = await cache.get_best_node(key, key_hint=key_hint) + + for node in filter(lambda n: n.alive, client._nodes): + result = await cache_local_peek_async(node, cache.cache_id, key, key_hint=key_hint) + + check_peek_value(node, best_node, result) - cache_1.destroy() + return inner_async() if isinstance(client, AioClient) else inner() diff --git a/tests/affinity/test_affinity_bad_servers.py b/tests/affinity/test_affinity_bad_servers.py index 6fd08d5..b169168 100644 --- a/tests/affinity/test_affinity_bad_servers.py +++ b/tests/affinity/test_affinity_bad_servers.py @@ -15,9 +15,9 @@ import pytest -from pyignite.exceptions import ReconnectError +from pyignite.exceptions import ReconnectError, connection_errors from tests.affinity.conftest import CLIENT_SOCKET_TIMEOUT -from tests.util import start_ignite, kill_process_tree, get_client +from tests.util import start_ignite, kill_process_tree, get_client, get_client_async @pytest.fixture(params=['with-partition-awareness', 'without-partition-awareness']) @@ -26,10 +26,16 @@ def with_partition_awareness(request): def test_client_with_multiple_bad_servers(with_partition_awareness): - with pytest.raises(ReconnectError) as e_info: + with pytest.raises(ReconnectError, match="Can not connect."): with get_client(partition_aware=with_partition_awareness) as client: client.connect([("127.0.0.1", 10900), ("127.0.0.1", 10901)]) - assert str(e_info.value) == "Can not connect." + + +@pytest.mark.asyncio +async def test_client_with_multiple_bad_servers_async(with_partition_awareness): + with pytest.raises(ReconnectError, match="Can not connect."): + async with get_client_async(partition_aware=with_partition_awareness) as client: + await client.connect([("127.0.0.1", 10900), ("127.0.0.1", 10901)]) def test_client_with_failed_server(request, with_partition_awareness): @@ -52,6 +58,27 @@ def test_client_with_failed_server(request, with_partition_awareness): kill_process_tree(srv.pid) +@pytest.mark.asyncio +async def test_client_with_failed_server_async(request, with_partition_awareness): + srv = start_ignite(idx=4) + try: + async with get_client_async(partition_aware=with_partition_awareness) as client: + await client.connect([("127.0.0.1", 10804)]) + cache = await client.get_or_create_cache(request.node.name) + await cache.put(1, 1) + kill_process_tree(srv.pid) + + if with_partition_awareness: + ex_class = (ReconnectError, ConnectionResetError) + else: + ex_class = ConnectionResetError + + with pytest.raises(ex_class): + await cache.get(1) + finally: + kill_process_tree(srv.pid) + + def test_client_with_recovered_server(request, with_partition_awareness): srv = start_ignite(idx=4) try: @@ -67,7 +94,7 @@ def test_client_with_recovered_server(request, with_partition_awareness): # First request may fail. try: cache.put(1, 2) - except: + except connection_errors: pass # Retry succeeds @@ -75,3 +102,29 @@ def test_client_with_recovered_server(request, with_partition_awareness): assert cache.get(1) == 2 finally: kill_process_tree(srv.pid) + + +@pytest.mark.asyncio +async def test_client_with_recovered_server_async(request, with_partition_awareness): + srv = start_ignite(idx=4) + try: + async with get_client_async(partition_aware=with_partition_awareness) as client: + await client.connect([("127.0.0.1", 10804)]) + cache = await client.get_or_create_cache(request.node.name) + await cache.put(1, 1) + + # Kill and restart server + kill_process_tree(srv.pid) + srv = start_ignite(idx=4) + + # First request may fail. + try: + await cache.put(1, 2) + except connection_errors: + pass + + # Retry succeeds + await cache.put(1, 2) + assert await cache.get(1) == 2 + finally: + kill_process_tree(srv.pid) diff --git a/tests/affinity/test_affinity_request_routing.py b/tests/affinity/test_affinity_request_routing.py index 101db39..64197ff 100644 --- a/tests/affinity/test_affinity_request_routing.py +++ b/tests/affinity/test_affinity_request_routing.py @@ -13,20 +13,24 @@ # See the License for the specific language governing permissions and # limitations under the License. +import asyncio from collections import OrderedDict, deque +import random + import pytest -from pyignite import * -from pyignite.connection import Connection +from pyignite import GenericObjectMeta, AioClient, Client +from pyignite.aio_cache import AioCache +from pyignite.connection import Connection, AioConnection from pyignite.constants import PROTOCOL_BYTE_ORDER -from pyignite.datatypes import * +from pyignite.datatypes import String, LongObject from pyignite.datatypes.cache_config import CacheMode -from pyignite.datatypes.prop_codes import * -from tests.util import * - +from pyignite.datatypes.prop_codes import PROP_NAME, PROP_BACKUPS_NUMBER, PROP_CACHE_KEY_CONFIGURATION, PROP_CACHE_MODE +from tests.util import wait_for_condition, wait_for_condition_async, start_ignite, kill_process_tree requests = deque() old_send = Connection.send +old_send_async = AioConnection._send def patched_send(self, *args, **kwargs): @@ -40,13 +44,26 @@ def patched_send(self, *args, **kwargs): return old_send(self, *args, **kwargs) +async def patched_send_async(self, *args, **kwargs): + """Patched send function that push to queue idx of server to which request is routed.""" + buf = args[0] + if buf and len(buf) >= 6: + op_code = int.from_bytes(buf[4:6], byteorder=PROTOCOL_BYTE_ORDER) + # Filter only caches operation. + if 1000 <= op_code < 1100: + requests.append(self.port % 100) + return await old_send_async(self, *args, **kwargs) + + def setup_function(): requests.clear() Connection.send = patched_send + AioConnection._send = patched_send_async def teardown_function(): Connection.send = old_send + AioConnection.send = old_send_async def wait_for_affinity_distribution(cache, key, node_idx, timeout=30): @@ -68,6 +85,25 @@ def check_grid_idx(): f"got {real_node_idx} instead") +async def wait_for_affinity_distribution_async(cache, key, node_idx, timeout=30): + real_node_idx = 0 + + async def check_grid_idx(): + nonlocal real_node_idx + try: + await cache.get(key) + real_node_idx = requests.pop() + except (OSError, IOError): + return False + return real_node_idx == node_idx + + res = await wait_for_condition_async(check_grid_idx, timeout=timeout) + + if not res: + raise TimeoutError(f"failed to wait for affinity distribution, expected node_idx {node_idx}," + f"got {real_node_idx} instead") + + @pytest.mark.parametrize("key,grid_idx", [(1, 1), (2, 2), (3, 3), (4, 1), (5, 1), (6, 2), (11, 1), (13, 1), (19, 1)]) @pytest.mark.parametrize("backups", [0, 1, 2, 3]) def test_cache_operation_on_primitive_key_routes_request_to_primary_node(request, key, grid_idx, backups, client): @@ -75,52 +111,56 @@ def test_cache_operation_on_primitive_key_routes_request_to_primary_node(request PROP_NAME: request.node.name + str(backups), PROP_BACKUPS_NUMBER: backups, }) + try: + __perform_operations_on_primitive_key(client, cache, key, grid_idx) + finally: + cache.destroy() - cache.put(key, key) - wait_for_affinity_distribution(cache, key, grid_idx) - - # Test - cache.get(key) - assert requests.pop() == grid_idx - - cache.put(key, key) - assert requests.pop() == grid_idx - - cache.replace(key, key + 1) - assert requests.pop() == grid_idx - - cache.clear_key(key) - assert requests.pop() == grid_idx - - cache.contains_key(key) - assert requests.pop() == grid_idx - cache.get_and_put(key, 3) - assert requests.pop() == grid_idx +@pytest.mark.parametrize("key,grid_idx", [(1, 1), (2, 2), (3, 3), (4, 1), (5, 1), (6, 2), (11, 1), (13, 1), (19, 1)]) +@pytest.mark.parametrize("backups", [0, 1, 2, 3]) +@pytest.mark.asyncio +async def test_cache_operation_on_primitive_key_routes_request_to_primary_node_async( + request, key, grid_idx, backups, async_client): + cache = await async_client.get_or_create_cache({ + PROP_NAME: request.node.name + str(backups), + PROP_BACKUPS_NUMBER: backups, + }) + try: + await __perform_operations_on_primitive_key(async_client, cache, key, grid_idx) + finally: + await cache.destroy() - cache.get_and_put_if_absent(key, 4) - assert requests.pop() == grid_idx - cache.put_if_absent(key, 5) - assert requests.pop() == grid_idx +def __perform_operations_on_primitive_key(client, cache, key, grid_idx): + operations = [ + ('get', 1), ('put', 2), ('replace', 2), ('clear_key', 1), ('contains_key', 1), ('get_and_put', 2), + ('get_and_put_if_absent', 2), ('put_if_absent', 2), ('get_and_remove', 1), ('get_and_replace', 2), + ('remove_key', 1), ('remove_if_equals', 2), ('replace', 2), ('replace_if_equals', 3) + ] - cache.get_and_remove(key) - assert requests.pop() == grid_idx + def inner(): + cache.put(key, key) + wait_for_affinity_distribution(cache, key, grid_idx) - cache.get_and_replace(key, 6) - assert requests.pop() == grid_idx + for op_name, param_nums in operations: + op = getattr(cache, op_name) + args = [random.randint(-100, 100) for _ in range(0, param_nums - 1)] + op(key, *args) + assert requests.pop() == grid_idx - cache.remove_key(key) - assert requests.pop() == grid_idx + async def inner_async(): + await cache.put(key, key) + await wait_for_affinity_distribution_async(cache, key, grid_idx) - cache.remove_if_equals(key, -1) - assert requests.pop() == grid_idx + for op_name, param_nums in operations: + op = getattr(cache, op_name) + args = [random.randint(-100, 100) for _ in range(0, param_nums - 1)] + await op(key, *args) - cache.replace(key, -1) - assert requests.pop() == grid_idx + assert requests.pop() == grid_idx - cache.replace_if_equals(key, 10, -10) - assert requests.pop() == grid_idx + return inner_async() if isinstance(client, AioClient) else inner() @pytest.mark.skip(reason="Custom key objects are not supported yet") @@ -164,50 +204,144 @@ class AffinityTestType1( assert requests.pop() == grid_idx -def test_cache_operation_routed_to_new_cluster_node(request, client_not_connected): - client_not_connected.connect( - [("127.0.0.1", 10801), ("127.0.0.1", 10802), ("127.0.0.1", 10803), ("127.0.0.1", 10804)] - ) - cache = client_not_connected.get_or_create_cache(request.node.name) - key = 12 - wait_for_affinity_distribution(cache, key, 3) - cache.put(key, key) - cache.put(key, key) - assert requests.pop() == 3 +client_routed_connection_string = [('127.0.0.1', 10800 + idx) for idx in range(1, 5)] + - srv = start_ignite(idx=4) +@pytest.fixture +def client_routed_cache(request): + client = Client(partition_aware=True) try: - # Wait for rebalance and partition map exchange - wait_for_affinity_distribution(cache, key, 4) + client.connect(client_routed_connection_string) + yield client.get_or_create_cache(request.node.name) + finally: + client.close() + - # Response is correct and comes from the new node - res = cache.get_and_remove(key) - assert res == key - assert requests.pop() == 4 +@pytest.fixture +async def async_client_routed_cache(request): + client = AioClient(partition_aware=True) + try: + await client.connect(client_routed_connection_string) + yield await client.get_or_create_cache(request.node.name) finally: - kill_process_tree(srv.pid) + await client.close() + + +def test_cache_operation_routed_to_new_cluster_node(client_routed_cache): + __perform_cache_operation_routed_to_new_node(client_routed_cache) + + +@pytest.mark.asyncio +async def test_cache_operation_routed_to_new_cluster_node_async(async_client_routed_cache): + await __perform_cache_operation_routed_to_new_node(async_client_routed_cache) + + +def __perform_cache_operation_routed_to_new_node(cache): + key = 12 + + def inner(): + wait_for_affinity_distribution(cache, key, 3) + cache.put(key, key) + cache.put(key, key) + assert requests.pop() == 3 + + srv = start_ignite(idx=4) + try: + # Wait for rebalance and partition map exchange + wait_for_affinity_distribution(cache, key, 4) + + # Response is correct and comes from the new node + res = cache.get_and_remove(key) + assert res == key + assert requests.pop() == 4 + finally: + kill_process_tree(srv.pid) + + async def inner_async(): + await wait_for_affinity_distribution_async(cache, key, 3) + await cache.put(key, key) + await cache.put(key, key) + assert requests.pop() == 3 + + srv = start_ignite(idx=4) + try: + # Wait for rebalance and partition map exchange + await wait_for_affinity_distribution_async(cache, key, 4) + + # Response is correct and comes from the new node + res = await cache.get_and_remove(key) + assert res == key + assert requests.pop() == 4 + finally: + kill_process_tree(srv.pid) + + return inner_async() if isinstance(cache, AioCache) else inner() -def test_replicated_cache_operation_routed_to_random_node(request, client): +@pytest.fixture +def replicated_cache(request, client): cache = client.get_or_create_cache({ PROP_NAME: request.node.name, PROP_CACHE_MODE: CacheMode.REPLICATED, }) + try: + yield cache + finally: + cache.destroy() + + +@pytest.fixture +async def async_replicated_cache(request, async_client): + cache = await async_client.get_or_create_cache({ + PROP_NAME: request.node.name, + PROP_CACHE_MODE: CacheMode.REPLICATED, + }) + try: + yield cache + finally: + await cache.destroy() - verify_random_node(cache) + +def test_replicated_cache_operation_routed_to_random_node(replicated_cache): + verify_random_node(replicated_cache) + + +@pytest.mark.asyncio +async def test_replicated_cache_operation_routed_to_random_node_async(async_replicated_cache): + await verify_random_node(async_replicated_cache) def verify_random_node(cache): key = 1 - cache.put(key, key) - idx1 = requests.pop() - idx2 = idx1 - - # Try 10 times - random node may end up being the same - for _ in range(1, 10): + def inner(): cache.put(key, key) - idx2 = requests.pop() - if idx2 != idx1: - break - assert idx1 != idx2 + + idx1 = requests.pop() + idx2 = idx1 + + # Try 10 times - random node may end up being the same + for _ in range(1, 10): + cache.put(key, key) + idx2 = requests.pop() + if idx2 != idx1: + break + assert idx1 != idx2 + + async def inner_async(): + await cache.put(key, key) + + idx1 = requests.pop() + + idx2 = idx1 + + # Try 10 times - random node may end up being the same + for _ in range(1, 10): + await cache.put(key, key) + idx2 = requests.pop() + + if idx2 != idx1: + break + assert idx1 != idx2 + + return inner_async() if isinstance(cache, AioCache) else inner() diff --git a/tests/affinity/test_affinity_single_connection.py b/tests/affinity/test_affinity_single_connection.py index 0768011..c3d2473 100644 --- a/tests/affinity/test_affinity_single_connection.py +++ b/tests/affinity/test_affinity_single_connection.py @@ -15,15 +15,27 @@ import pytest -from pyignite import Client +from pyignite import Client, AioClient -@pytest.fixture(scope='module') +@pytest.fixture def client(): client = Client(partition_aware=True) - client.connect('127.0.0.1', 10801) - yield client - client.close() + try: + client.connect('127.0.0.1', 10801) + yield client + finally: + client.close() + + +@pytest.fixture +async def async_client(): + client = AioClient(partition_aware=True) + try: + await client.connect('127.0.0.1', 10801) + yield client + finally: + await client.close() def test_all_cache_operations_with_partition_aware_client_on_single_server(request, client): @@ -108,3 +120,88 @@ def test_all_cache_operations_with_partition_aware_client_on_single_server(reque assert not res assert res2 assert cache.get(key) == key2 + + +@pytest.mark.asyncio +async def test_all_cache_operations_with_partition_aware_client_on_single_server_async(request, async_client): + cache = await async_client.get_or_create_cache(request.node.name) + key = 1 + key2 = 2 + + # Put/Get + await cache.put(key, key) + assert await cache.get(key) == key + + # Replace + res = await cache.replace(key, key2) + assert res + assert await cache.get(key) == key2 + + # Clear + await cache.put(key2, key2) + await cache.clear_key(key2) + assert await cache.get(key2) is None + + # ContainsKey + assert await cache.contains_key(key) + assert not await cache.contains_key(key2) + + # GetAndPut + await cache.put(key, key) + res = await cache.get_and_put(key, key2) + assert res == key + assert await cache.get(key) == key2 + + # GetAndPutIfAbsent + await cache.clear_key(key) + res = await cache.get_and_put_if_absent(key, key) + res2 = await cache.get_and_put_if_absent(key, key2) + assert res is None + assert res2 == key + assert await cache.get(key) == key + + # PutIfAbsent + await cache.clear_key(key) + res = await cache.put_if_absent(key, key) + res2 = await cache.put_if_absent(key, key2) + assert res + assert not res2 + assert await cache.get(key) == key + + # GetAndRemove + await cache.put(key, key) + res = await cache.get_and_remove(key) + assert res == key + assert await cache.get(key) is None + + # GetAndReplace + await cache.put(key, key) + res = await cache.get_and_replace(key, key2) + assert res == key + assert await cache.get(key) == key2 + + # RemoveKey + await cache.put(key, key) + await cache.remove_key(key) + assert await cache.get(key) is None + + # RemoveIfEquals + await cache.put(key, key) + res = await cache.remove_if_equals(key, key2) + res2 = await cache.remove_if_equals(key, key) + assert not res + assert res2 + assert await cache.get(key) is None + + # Replace + await cache.put(key, key) + await cache.replace(key, key2) + assert await cache.get(key) == key2 + + # ReplaceIfEquals + await cache.put(key, key) + res = await cache.replace_if_equals(key, key2, key2) + res2 = await cache.replace_if_equals(key, key, key2) + assert not res + assert res2 + assert await cache.get(key) == key2 diff --git a/tests/common/conftest.py b/tests/common/conftest.py index 402aede..243d822 100644 --- a/tests/common/conftest.py +++ b/tests/common/conftest.py @@ -15,8 +15,7 @@ import pytest -from pyignite import Client -from pyignite.api import cache_create, cache_destroy +from pyignite import Client, AioClient from tests.util import start_ignite_gen @@ -38,19 +37,36 @@ def server3(): @pytest.fixture(scope='module') def client(): client = Client() + try: + client.connect('127.0.0.1', 10801) + yield client + finally: + client.close() - client.connect('127.0.0.1', 10801) - yield client +@pytest.fixture(scope='module') +async def async_client(event_loop): + client = AioClient() + try: + await client.connect('127.0.0.1', 10801) + yield client + finally: + await client.close() + - client.close() +@pytest.fixture +async def async_cache(async_client: 'AioClient'): + cache = await async_client.create_cache('my_bucket') + try: + yield cache + finally: + await cache.destroy() @pytest.fixture def cache(client): - cache_name = 'my_bucket' - conn = client.random_node - - cache_create(conn, cache_name) - yield cache_name - cache_destroy(conn, cache_name) + cache = client.create_cache('my_bucket') + try: + yield cache + finally: + cache.destroy() diff --git a/tests/common/test_binary.py b/tests/common/test_binary.py index 5fa2ec4..1d7192f 100644 --- a/tests/common/test_binary.py +++ b/tests/common/test_binary.py @@ -16,15 +16,17 @@ from collections import OrderedDict from decimal import Decimal +import pytest + from pyignite import GenericObjectMeta +from pyignite.aio_cache import AioCache from pyignite.datatypes import ( BinaryObject, BoolObject, IntObject, DecimalObject, LongObject, String, ByteObject, ShortObject, FloatObject, DoubleObject, CharObject, UUIDObject, DateObject, TimestampObject, TimeObject, EnumObject, BinaryEnumObject, ByteArrayObject, ShortArrayObject, IntArrayObject, LongArrayObject, FloatArrayObject, DoubleArrayObject, CharArrayObject, BoolArrayObject, UUIDArrayObject, DateArrayObject, TimestampArrayObject, TimeArrayObject, EnumArrayObject, StringArrayObject, DecimalArrayObject, ObjectArrayObject, CollectionObject, MapObject) -from pyignite.datatypes.prop_codes import * - +from pyignite.datatypes.prop_codes import PROP_NAME, PROP_SQL_SCHEMA, PROP_QUERY_ENTITIES insert_data = [ [1, True, 'asdf', 42, Decimal('2.4')], @@ -54,7 +56,7 @@ insert_query = ''' INSERT INTO {} ( - test_pk, test_bool, test_str, test_int, test_decimal, + test_pk, test_bool, test_str, test_int, test_decimal, ) VALUES (?, ?, ?, ?, ?)'''.format(table_sql_name) select_query = '''SELECT * FROM {}'''.format(table_sql_name) @@ -62,51 +64,69 @@ drop_query = 'DROP TABLE {} IF EXISTS'.format(table_sql_name) -def test_sql_read_as_binary(client): +@pytest.fixture +def table_cache_read(client): client.sql(drop_query) - - # create table client.sql(create_query) - # insert some rows for line in insert_data: client.sql(insert_query, query_args=line) - table_cache = client.get_cache(table_cache_name) - result = table_cache.scan() - - # convert Binary object fields' values to a tuple - # to compare it with the initial data - for key, value in result: - assert key in {x[0] for x in insert_data} - assert ( - value.TEST_BOOL, - value.TEST_STR, - value.TEST_INT, - value.TEST_DECIMAL - ) in {tuple(x[1:]) for x in insert_data} - - client.sql(drop_query) + cache = client.get_cache(table_cache_name) + yield cache + cache.destroy() -def test_sql_write_as_binary(client): - # configure cache as an SQL table - type_name = table_cache_name +@pytest.fixture +async def table_cache_read_async(async_client): + await async_client.sql(drop_query) + await async_client.sql(create_query) - # register binary type - class AllDataType( - metaclass=GenericObjectMeta, - type_name=type_name, - schema=OrderedDict([ - ('TEST_BOOL', BoolObject), - ('TEST_STR', String), - ('TEST_INT', IntObject), - ('TEST_DECIMAL', DecimalObject), - ]), - ): - pass - - table_cache = client.get_or_create_cache({ + for line in insert_data: + await async_client.sql(insert_query, query_args=line) + + cache = await async_client.get_cache(table_cache_name) + yield cache + await cache.destroy() + + +def test_sql_read_as_binary(table_cache_read): + with table_cache_read.scan() as cursor: + # convert Binary object fields' values to a tuple + # to compare it with the initial data + for key, value in cursor: + assert key in {x[0] for x in insert_data} + assert (value.TEST_BOOL, value.TEST_STR, value.TEST_INT, value.TEST_DECIMAL) \ + in {tuple(x[1:]) for x in insert_data} + + +@pytest.mark.asyncio +async def test_sql_read_as_binary_async(table_cache_read_async): + async with table_cache_read_async.scan() as cursor: + # convert Binary object fields' values to a tuple + # to compare it with the initial data + async for key, value in cursor: + assert key in {x[0] for x in insert_data} + assert (value.TEST_BOOL, value.TEST_STR, value.TEST_INT, value.TEST_DECIMAL) \ + in {tuple(x[1:]) for x in insert_data} + + +class AllDataType( + metaclass=GenericObjectMeta, + type_name=table_cache_name, + schema=OrderedDict([ + ('TEST_BOOL', BoolObject), + ('TEST_STR', String), + ('TEST_INT', IntObject), + ('TEST_DECIMAL', DecimalObject), + ]), +): + pass + + +@pytest.fixture +def table_cache_write_settings(): + return { PROP_NAME: table_cache_name, PROP_SQL_SCHEMA: scheme_name, PROP_QUERY_ENTITIES: [ @@ -142,15 +162,18 @@ class AllDataType( }, ], 'query_indexes': [], - 'value_type_name': type_name, + 'value_type_name': table_cache_name, 'value_field_name': None, }, ], - }) - table_settings = table_cache.settings - assert table_settings, 'SQL table cache settings are empty' + } + + +@pytest.fixture +def table_cache_write(client, table_cache_write_settings): + cache = client.get_or_create_cache(table_cache_write_settings) + assert cache.settings, 'SQL table cache settings are empty' - # insert rows as k-v for row in insert_data: value = AllDataType() ( @@ -159,13 +182,39 @@ class AllDataType( value.TEST_INT, value.TEST_DECIMAL, ) = row[1:] - table_cache.put(row[0], value, key_hint=IntObject) + cache.put(row[0], value, key_hint=IntObject) + + data = cache.scan() + assert len(list(data)) == len(insert_data), 'Not all data was read as key-value' + + yield cache + cache.destroy() - data = table_cache.scan() - assert len(list(data)) == len(insert_data), ( - 'Not all data was read as key-value' - ) +@pytest.fixture +async def async_table_cache_write(async_client, table_cache_write_settings): + cache = await async_client.get_or_create_cache(table_cache_write_settings) + assert await cache.settings(), 'SQL table cache settings are empty' + + for row in insert_data: + value = AllDataType() + ( + value.TEST_BOOL, + value.TEST_STR, + value.TEST_INT, + value.TEST_DECIMAL, + ) = row[1:] + await cache.put(row[0], value, key_hint=IntObject) + + async with cache.scan() as cursor: + data = [a async for a in cursor] + assert len(data) == len(insert_data), 'Not all data was read as key-value' + + yield cache + await cache.destroy() + + +def test_sql_write_as_binary(client, table_cache_write): # read rows as SQL data = client.sql(select_query, include_field_names=True) @@ -176,14 +225,29 @@ class AllDataType( data = list(data) assert len(data) == len(insert_data), 'Not all data was read as SQL rows' - # cleanup - table_cache.destroy() + +@pytest.mark.asyncio +async def test_sql_write_as_binary_async(async_client, async_table_cache_write): + # read rows as SQL + async with async_client.sql(select_query, include_field_names=True) as cursor: + header_row = await cursor.__anext__() + for field_name in AllDataType.schema.keys(): + assert field_name in header_row, 'Not all field names in header row' + + data = [v async for v in cursor] + assert len(data) == len(insert_data), 'Not all data was read as SQL rows' -def test_nested_binary_objects(client): +def test_nested_binary_objects(cache): + __check_nested_binary_objects(cache) - nested_cache = client.get_or_create_cache('nested_binary') +@pytest.mark.asyncio +async def test_nested_binary_objects_async(async_cache): + await __check_nested_binary_objects(async_cache) + + +def __check_nested_binary_objects(cache): class InnerType( metaclass=GenericObjectMeta, schema=OrderedDict([ @@ -203,29 +267,42 @@ class OuterType( ): pass - inner = InnerType(inner_int=42, inner_str='This is a test string') + def prepare_obj(): + inner = InnerType(inner_int=42, inner_str='This is a test string') + + return OuterType( + outer_int=43, + nested_binary=inner, + outer_str='This is another test string' + ) - outer = OuterType( - outer_int=43, - nested_binary=inner, - outer_str='This is another test string' - ) + def check_obj(result): + assert result.outer_int == 43 + assert result.outer_str == 'This is another test string' + assert result.nested_binary.inner_int == 42 + assert result.nested_binary.inner_str == 'This is a test string' - nested_cache.put(1, outer) + async def inner_async(): + await cache.put(1, prepare_obj()) + check_obj(await cache.get(1)) - result = nested_cache.get(1) - assert result.outer_int == 43 - assert result.outer_str == 'This is another test string' - assert result.nested_binary.inner_int == 42 - assert result.nested_binary.inner_str == 'This is a test string' + def inner(): + cache.put(1, prepare_obj()) + check_obj(cache.get(1)) - nested_cache.destroy() + return inner_async() if isinstance(cache, AioCache) else inner() -def test_add_schema_to_binary_object(client): +def test_add_schema_to_binary_object(cache): + __check_add_schema_to_binary_object(cache) - migrate_cache = client.get_or_create_cache('migrate_binary') +@pytest.mark.asyncio +async def test_add_schema_to_binary_object_async(async_cache): + await __check_add_schema_to_binary_object(async_cache) + + +def __check_add_schema_to_binary_object(cache): class MyBinaryType( metaclass=GenericObjectMeta, schema=OrderedDict([ @@ -236,54 +313,66 @@ class MyBinaryType( ): pass - binary_object = MyBinaryType( - test_str='Test string', - test_int=42, - test_bool=True, - ) - migrate_cache.put(1, binary_object) + def prepare_bo_v1(): + return MyBinaryType(test_str='Test string', test_int=42, test_bool=True) - result = migrate_cache.get(1) - assert result.test_str == 'Test string' - assert result.test_int == 42 - assert result.test_bool is True + def check_bo_v1(result): + assert result.test_str == 'Test string' + assert result.test_int == 42 + assert result.test_bool is True - modified_schema = MyBinaryType.schema.copy() - modified_schema['test_decimal'] = DecimalObject - del modified_schema['test_bool'] + def prepare_bo_v2(): + modified_schema = MyBinaryType.schema.copy() + modified_schema['test_decimal'] = DecimalObject + del modified_schema['test_bool'] - class MyBinaryTypeV2( - metaclass=GenericObjectMeta, - type_name='MyBinaryType', - schema=modified_schema, - ): - pass + class MyBinaryTypeV2( + metaclass=GenericObjectMeta, + type_name='MyBinaryType', + schema=modified_schema, + ): + pass + + assert MyBinaryType.type_id == MyBinaryTypeV2.type_id + assert MyBinaryType.schema_id != MyBinaryTypeV2.schema_id - assert MyBinaryType.type_id == MyBinaryTypeV2.type_id - assert MyBinaryType.schema_id != MyBinaryTypeV2.schema_id + return MyBinaryTypeV2(test_str='Another test', test_int=43, test_decimal=Decimal('2.34')) - binary_object_v2 = MyBinaryTypeV2( - test_str='Another test', - test_int=43, - test_decimal=Decimal('2.34') - ) + def check_bo_v2(result): + assert result.test_str == 'Another test' + assert result.test_int == 43 + assert result.test_decimal == Decimal('2.34') + assert not hasattr(result, 'test_bool') - migrate_cache.put(2, binary_object_v2) + async def inner_async(): + await cache.put(1, prepare_bo_v1()) + check_bo_v1(await cache.get(1)) + await cache.put(2, prepare_bo_v2()) + check_bo_v2(await cache.get(2)) - result = migrate_cache.get(2) - assert result.test_str == 'Another test' - assert result.test_int == 43 - assert result.test_decimal == Decimal('2.34') - assert not hasattr(result, 'test_bool') + def inner(): + cache.put(1, prepare_bo_v1()) + check_bo_v1(cache.get(1)) + cache.put(2, prepare_bo_v2()) + check_bo_v2(cache.get(2)) - migrate_cache.destroy() + return inner_async() if isinstance(cache, AioCache) else inner() -def test_complex_object_names(client): +def test_complex_object_names(cache): """ Test the ability to work with Complex types, which names contains symbols not suitable for use in Python identifiers. """ + __check_complex_object_names(cache) + + +@pytest.mark.asyncio +async def test_complex_object_names_async(async_cache): + await __check_complex_object_names(async_cache) + + +def __check_complex_object_names(cache): type_name = 'Non.Pythonic#type-name$' key = 'key' data = 'test' @@ -297,41 +386,47 @@ class NonPythonicallyNamedType( ): pass - cache = client.get_or_create_cache('test_name_cache') - cache.put(key, NonPythonicallyNamedType(field=data)) + def check(obj): + assert obj.type_name == type_name, 'Complex type name mismatch' + assert obj.field == data, 'Complex object data failure' - obj = cache.get(key) - assert obj.type_name == type_name, 'Complex type name mismatch' - assert obj.field == data, 'Complex object data failure' + async def inner_async(): + await cache.put(key, NonPythonicallyNamedType(field=data)) + check(await cache.get(key)) + def inner(): + cache.put(key, NonPythonicallyNamedType(field=data)) + check(cache.get(key)) -def test_complex_object_hash(client): - """ - Test that Python client correctly calculates hash of the binary object that - contains negative bytes. - """ - class Internal( - metaclass=GenericObjectMeta, - type_name='Internal', - schema=OrderedDict([ - ('id', IntObject), - ('str', String), - ]) - ): - pass + return inner_async() if isinstance(cache, AioCache) else inner() - class TestObject( - metaclass=GenericObjectMeta, - type_name='TestObject', - schema=OrderedDict([ - ('id', IntObject), - ('str', String), - ('internal', BinaryObject), - ]) - ): - pass - obj_ascii = TestObject() +class Internal( + metaclass=GenericObjectMeta, type_name='Internal', + schema=OrderedDict([ + ('id', IntObject), + ('str', String) + ]) +): + pass + + +class NestedObject( + metaclass=GenericObjectMeta, type_name='NestedObject', + schema=OrderedDict([ + ('id', IntObject), + ('str', String), + ('internal', BinaryObject) + ]) +): + pass + + +@pytest.fixture +def complex_objects(): + fixtures = [] + + obj_ascii = NestedObject() obj_ascii.id = 1 obj_ascii.str = 'test_string' @@ -339,11 +434,9 @@ class TestObject( obj_ascii.internal.id = 2 obj_ascii.internal.str = 'lorem ipsum' - hash_ascii = BinaryObject.hashcode(obj_ascii, client=client) - - assert hash_ascii == -1314567146, 'Invalid hashcode value for object with ASCII strings' + fixtures.append((obj_ascii, -1314567146)) - obj_utf8 = TestObject() + obj_utf8 = NestedObject() obj_utf8.id = 1 obj_utf8.str = 'юникод' @@ -351,39 +444,63 @@ class TestObject( obj_utf8.internal.id = 2 obj_utf8.internal.str = 'ユニコード' - hash_utf8 = BinaryObject.hashcode(obj_utf8, client=client) + fixtures.append((obj_utf8, -1945378474)) - assert hash_utf8 == -1945378474, 'Invalid hashcode value for object with UTF-8 strings' + yield fixtures -def test_complex_object_null_fields(client): - """ - Test that Python client can correctly write and read binary object that - contains null fields. - """ - def camel_to_snake(name): - return re.sub('([a-z0-9])([A-Z])', r'\1_\2', name).lower() - - fields = {camel_to_snake(type_.__name__): type_ for type_ in [ - ByteObject, ShortObject, IntObject, LongObject, FloatObject, DoubleObject, CharObject, BoolObject, UUIDObject, - DateObject, TimestampObject, TimeObject, EnumObject, BinaryEnumObject, ByteArrayObject, ShortArrayObject, - IntArrayObject, LongArrayObject, FloatArrayObject, DoubleArrayObject, CharArrayObject, BoolArrayObject, - UUIDArrayObject, DateArrayObject, TimestampArrayObject, TimeArrayObject, EnumArrayObject, String, - StringArrayObject, DecimalObject, DecimalArrayObject, ObjectArrayObject, CollectionObject, MapObject, - BinaryObject]} - - class AllTypesObject(metaclass=GenericObjectMeta, type_name='AllTypesObject', schema=fields): - pass +def test_complex_object_hash(client, complex_objects): + for obj, hash in complex_objects: + assert hash == BinaryObject.hashcode(obj, client) + + +@pytest.mark.asyncio +async def test_complex_object_hash_async(async_client, complex_objects): + for obj, hash in complex_objects: + assert hash == await BinaryObject.hashcode_async(obj, async_client) + + +def camel_to_snake(name): + return re.sub('([a-z0-9])([A-Z])', r'\1_\2', name).lower() + + +fields = {camel_to_snake(type_.__name__): type_ for type_ in [ + ByteObject, ShortObject, IntObject, LongObject, FloatObject, DoubleObject, CharObject, BoolObject, UUIDObject, + DateObject, TimestampObject, TimeObject, EnumObject, BinaryEnumObject, ByteArrayObject, ShortArrayObject, + IntArrayObject, LongArrayObject, FloatArrayObject, DoubleArrayObject, CharArrayObject, BoolArrayObject, + UUIDArrayObject, DateArrayObject, TimestampArrayObject, TimeArrayObject, EnumArrayObject, String, + StringArrayObject, DecimalObject, DecimalArrayObject, ObjectArrayObject, CollectionObject, MapObject, + BinaryObject]} + + +class AllTypesObject(metaclass=GenericObjectMeta, type_name='AllTypesObject', schema=fields): + pass - key = 42 - null_fields_value = AllTypesObject() + +@pytest.fixture +def null_fields_object(): + res = AllTypesObject() for field in fields.keys(): - setattr(null_fields_value, field, None) + setattr(res, field, None) + + yield res - cache = client.get_or_create_cache('all_types_test_cache') - cache.put(key, null_fields_value) - got_obj = cache.get(key) +def test_complex_object_null_fields(cache, null_fields_object): + """ + Test that Python client can correctly write and read binary object that + contains null fields. + """ + cache.put(1, null_fields_object) + assert cache.get(1) == null_fields_object, 'Objects mismatch' - assert got_obj == null_fields_value, 'Objects mismatch' + +@pytest.mark.asyncio +async def test_complex_object_null_fields_async(async_cache, null_fields_object): + """ + Test that Python client can correctly write and read binary object that + contains null fields. + """ + await async_cache.put(1, null_fields_object) + assert await async_cache.get(1) == null_fields_object, 'Objects mismatch' diff --git a/tests/common/test_cache_class.py b/tests/common/test_cache_class.py index 940160a..02dfa82 100644 --- a/tests/common/test_cache_class.py +++ b/tests/common/test_cache_class.py @@ -19,66 +19,56 @@ import pytest from pyignite import GenericObjectMeta -from pyignite.datatypes import ( - BoolObject, DecimalObject, FloatObject, IntObject, String, -) -from pyignite.datatypes.prop_codes import * +from pyignite.datatypes import BoolObject, DecimalObject, FloatObject, IntObject, String +from pyignite.datatypes.prop_codes import PROP_NAME, PROP_CACHE_KEY_CONFIGURATION from pyignite.exceptions import CacheError, ParameterError def test_cache_create(client): cache = client.get_or_create_cache('my_oop_cache') - assert cache.name == cache.settings[PROP_NAME] == 'my_oop_cache' - cache.destroy() - - -def test_cache_remove(client): - cache = client.get_or_create_cache('my_cache') - cache.clear() - assert cache.get_size() == 0 - - cache.put_all({ - 'key_1': 1, - 'key_2': 2, - 'key_3': 3, - 'key_4': 4, - 'key_5': 5, - }) - assert cache.get_size() == 5 - - result = cache.remove_if_equals('key_1', 42) - assert result is False - assert cache.get_size() == 5 - - result = cache.remove_if_equals('key_1', 1) - assert result is True - assert cache.get_size() == 4 + try: + assert cache.name == cache.settings[PROP_NAME] == 'my_oop_cache' + finally: + cache.destroy() - cache.remove_keys(['key_1', 'key_3', 'key_5', 'key_7']) - assert cache.get_size() == 2 - cache.remove_all() - assert cache.get_size() == 0 +@pytest.mark.asyncio +async def test_cache_create_async(async_client): + cache = await async_client.get_or_create_cache('my_oop_cache') + try: + assert (await cache.name()) == (await cache.settings())[PROP_NAME] == 'my_oop_cache' + finally: + await cache.destroy() -def test_cache_get(client): +def test_get_cache(client): my_cache = client.get_or_create_cache('my_cache') - assert my_cache.settings[PROP_NAME] == 'my_cache' - my_cache.destroy() - - error = None + try: + assert my_cache.settings[PROP_NAME] == 'my_cache' + finally: + my_cache.destroy() my_cache = client.get_cache('my_cache') - try: + with pytest.raises(CacheError): _ = my_cache.settings[PROP_NAME] - except CacheError as e: - error = e - assert type(error) is CacheError + +@pytest.mark.asyncio +async def test_get_cache_async(async_client): + my_cache = await async_client.get_or_create_cache('my_cache') + try: + assert (await my_cache.settings())[PROP_NAME] == 'my_cache' + finally: + await my_cache.destroy() + + my_cache = await async_client.get_cache('my_cache') + with pytest.raises(CacheError): + _ = (await my_cache.settings())[PROP_NAME] -def test_cache_config(client): - cache_config = { +@pytest.fixture +def cache_config(): + yield { PROP_NAME: 'my_oop_cache', PROP_CACHE_KEY_CONFIGURATION: [ { @@ -87,28 +77,31 @@ def test_cache_config(client): }, ], } - client.create_cache(cache_config) - - cache = client.get_or_create_cache('my_oop_cache') - assert cache.name == cache_config[PROP_NAME] - assert ( - cache.settings[PROP_CACHE_KEY_CONFIGURATION] - == cache_config[PROP_CACHE_KEY_CONFIGURATION] - ) - cache.destroy() - -def test_cache_get_put(client): +def test_cache_config(client, cache_config): + client.create_cache(cache_config) cache = client.get_or_create_cache('my_oop_cache') - cache.put('my_key', 42) - result = cache.get('my_key') - assert result, 42 - cache.destroy() + try: + assert cache.name == cache_config[PROP_NAME] + assert cache.settings[PROP_CACHE_KEY_CONFIGURATION] == cache_config[PROP_CACHE_KEY_CONFIGURATION] + finally: + cache.destroy() -def test_cache_binary_get_put(client): +@pytest.mark.asyncio +async def test_cache_config_async(async_client, cache_config): + await async_client.create_cache(cache_config) + cache = await async_client.get_or_create_cache('my_oop_cache') + try: + assert await cache.name() == cache_config[PROP_NAME] + assert (await cache.settings())[PROP_CACHE_KEY_CONFIGURATION] == cache_config[PROP_CACHE_KEY_CONFIGURATION] + finally: + await cache.destroy() + +@pytest.fixture +def binary_type_fixture(): class TestBinaryType( metaclass=GenericObjectMeta, schema=OrderedDict([ @@ -120,52 +113,63 @@ class TestBinaryType( ): pass - cache = client.create_cache('my_oop_cache') - - my_value = TestBinaryType( + return TestBinaryType( test_bool=True, test_str='This is a test', test_int=42, test_decimal=Decimal('34.56'), ) - cache.put('my_key', my_value) + +def test_cache_binary_get_put(cache, binary_type_fixture): + cache.put('my_key', binary_type_fixture) value = cache.get('my_key') - assert value.test_bool is True - assert value.test_str == 'This is a test' - assert value.test_int == 42 - assert value.test_decimal == Decimal('34.56') + assert value.test_bool == binary_type_fixture.test_bool + assert value.test_str == binary_type_fixture.test_str + assert value.test_int == binary_type_fixture.test_int + assert value.test_decimal == binary_type_fixture.test_decimal - cache.destroy() +@pytest.mark.asyncio +async def test_cache_binary_get_put_async(async_cache, binary_type_fixture): + await async_cache.put('my_key', binary_type_fixture) -def test_get_binary_type(client): - client.put_binary_type( - 'TestBinaryType', - schema=OrderedDict([ + value = await async_cache.get('my_key') + assert value.test_bool == binary_type_fixture.test_bool + assert value.test_str == binary_type_fixture.test_str + assert value.test_int == binary_type_fixture.test_int + assert value.test_decimal == binary_type_fixture.test_decimal + + +@pytest.fixture +def binary_type_schemas_fixture(): + schemas = [ + OrderedDict([ ('TEST_BOOL', BoolObject), ('TEST_STR', String), ('TEST_INT', IntObject), - ]) - ) - client.put_binary_type( - 'TestBinaryType', - schema=OrderedDict([ + ]), + OrderedDict([ ('TEST_BOOL', BoolObject), ('TEST_STR', String), ('TEST_INT', IntObject), ('TEST_FLOAT', FloatObject), - ]) - ) - client.put_binary_type( - 'TestBinaryType', - schema=OrderedDict([ + ]), + OrderedDict([ ('TEST_BOOL', BoolObject), ('TEST_STR', String), ('TEST_INT', IntObject), ('TEST_DECIMAL', DecimalObject), ]) - ) + ] + yield 'TestBinaryType', schemas + + +def test_get_binary_type(client, binary_type_schemas_fixture): + type_name, schemas = binary_type_schemas_fixture + + for schema in schemas: + client.put_binary_type(type_name, schema=schema) binary_type_info = client.get_binary_type('TestBinaryType') assert len(binary_type_info['schemas']) == 3 @@ -175,60 +179,37 @@ def test_get_binary_type(client): assert len(binary_type_info) == 1 -@pytest.mark.parametrize('page_size', range(1, 17, 5)) -def test_cache_scan(request, client, page_size): - test_data = { - 1: 'This is a test', - 2: 'One more test', - 3: 'Foo', - 4: 'Buzz', - 5: 'Bar', - 6: 'Lorem ipsum', - 7: 'dolor sit amet', - 8: 'consectetur adipiscing elit', - 9: 'Nullam aliquet', - 10: 'nisl at ante', - 11: 'suscipit', - 12: 'ut cursus', - 13: 'metus interdum', - 14: 'Nulla tincidunt', - 15: 'sollicitudin iaculis', - } - - cache = client.get_or_create_cache(request.node.name) - cache.put_all(test_data) - - gen = cache.scan(page_size=page_size) - received_data = [] - for k, v in gen: - assert k in test_data.keys() - assert v in test_data.values() - received_data.append((k, v)) - assert len(received_data) == len(test_data) +@pytest.mark.asyncio +async def test_get_binary_type_async(async_client, binary_type_schemas_fixture): + type_name, schemas = binary_type_schemas_fixture - cache.destroy() + for schema in schemas: + await async_client.put_binary_type(type_name, schema=schema) + binary_type_info = await async_client.get_binary_type('TestBinaryType') + assert len(binary_type_info['schemas']) == 3 -def test_get_and_put_if_absent(client): - cache = client.get_or_create_cache('my_oop_cache') - - value = cache.get_and_put_if_absent('my_key', 42) - assert value is None - cache.put('my_key', 43) - value = cache.get_and_put_if_absent('my_key', 42) - assert value is 43 + binary_type_info = await async_client.get_binary_type('NonExistentType') + assert binary_type_info['type_exists'] is False + assert len(binary_type_info) == 1 -def test_cache_get_when_cache_does_not_exist(client): +def test_get_cache_errors(client): cache = client.get_cache('missing-cache') - with pytest.raises(CacheError) as e_info: - cache.put(1, 1) - assert str(e_info.value) == "Cache does not exist [cacheId= 1665146971]" + with pytest.raises(CacheError, match=r'Cache does not exist \[cacheId='): + cache.put(1, 1) -def test_cache_create_with_none_name(client): - with pytest.raises(ParameterError) as e_info: + with pytest.raises(ParameterError, match="You should supply at least cache name"): client.create_cache(None) - assert str(e_info.value) == "You should supply at least cache name" +@pytest.mark.asyncio +async def test_get_cache_errors_async(async_client): + cache = await async_client.get_cache('missing-cache') + + with pytest.raises(CacheError, match=r'Cache does not exist \[cacheId='): + await cache.put(1, 1) + + with pytest.raises(ParameterError, match="You should supply at least cache name"): + await async_client.create_cache(None) diff --git a/tests/common/test_cache_class_sql.py b/tests/common/test_cache_class_sql.py deleted file mode 100644 index 5f72b39..0000000 --- a/tests/common/test_cache_class_sql.py +++ /dev/null @@ -1,103 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import pytest - - -initial_data = [ - ('John', 'Doe', 5), - ('Jane', 'Roe', 4), - ('Joe', 'Bloggs', 4), - ('Richard', 'Public', 3), - ('Negidius', 'Numerius', 3), - ] - -create_query = '''CREATE TABLE Student ( - id INT(11) PRIMARY KEY, - first_name CHAR(24), - last_name CHAR(32), - grade INT(11))''' - -insert_query = '''INSERT INTO Student(id, first_name, last_name, grade) -VALUES (?, ?, ?, ?)''' - -select_query = 'SELECT id, first_name, last_name, grade FROM Student' - -drop_query = 'DROP TABLE Student IF EXISTS' - - -@pytest.mark.parametrize('page_size', range(1, 6, 2)) -def test_sql_fields(client, page_size): - - client.sql(drop_query, page_size) - - result = client.sql(create_query, page_size) - assert next(result)[0] == 0 - - for i, data_line in enumerate(initial_data, start=1): - fname, lname, grade = data_line - result = client.sql( - insert_query, - page_size, - query_args=[i, fname, lname, grade] - ) - assert next(result)[0] == 1 - - result = client.sql( - select_query, - page_size, - include_field_names=True, - ) - field_names = next(result) - assert set(field_names) == {'ID', 'FIRST_NAME', 'LAST_NAME', 'GRADE'} - - data = list(result) - assert len(data) == 5 - for row in data: - assert len(row) == 4 - - client.sql(drop_query, page_size) - - -@pytest.mark.parametrize('page_size', range(1, 6, 2)) -def test_sql(client, page_size): - - client.sql(drop_query, page_size) - - result = client.sql(create_query, page_size) - assert next(result)[0] == 0 - - for i, data_line in enumerate(initial_data, start=1): - fname, lname, grade = data_line - result = client.sql( - insert_query, - page_size, - query_args=[i, fname, lname, grade] - ) - assert next(result)[0] == 1 - - student = client.get_or_create_cache('SQL_PUBLIC_STUDENT') - result = student.select_row('TRUE', page_size) - for k, v in result: - assert k in range(1, 6) - assert v.FIRST_NAME in [ - 'John', - 'Jane', - 'Joe', - 'Richard', - 'Negidius', - ] - - client.sql(drop_query, page_size) diff --git a/tests/common/test_cache_composite_key_class_sql.py b/tests/common/test_cache_composite_key_class_sql.py deleted file mode 100644 index 989a229..0000000 --- a/tests/common/test_cache_composite_key_class_sql.py +++ /dev/null @@ -1,122 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from collections import OrderedDict - -from pyignite import GenericObjectMeta -from pyignite.datatypes import ( - IntObject, String -) - - -class StudentKey( - metaclass=GenericObjectMeta, - type_name='test.model.StudentKey', - schema=OrderedDict([ - ('ID', IntObject), - ('DEPT', String) - ]) - ): - pass - - -class Student( - metaclass=GenericObjectMeta, - type_name='test.model.Student', - schema=OrderedDict([ - ('NAME', String), - ]) - ): - pass - - -create_query = '''CREATE TABLE StudentTable ( - id INT(11), - dept VARCHAR, - name CHAR(24), - PRIMARY KEY (id, dept)) - WITH "CACHE_NAME=StudentCache, KEY_TYPE=test.model.StudentKey, VALUE_TYPE=test.model.Student"''' - -insert_query = '''INSERT INTO StudentTable (id, dept, name) VALUES (?, ?, ?)''' - -select_query = 'SELECT _KEY, id, dept, name FROM StudentTable' - -drop_query = 'DROP TABLE StudentTable IF EXISTS' - - -def test_cache_get_with_composite_key_finds_sql_value(client): - """ - Should query a record with composite key and calculate - internal hashcode correctly. - """ - - client.sql(drop_query) - - # Create table. - result = client.sql(create_query) - assert next(result)[0] == 0 - - student_key = StudentKey(1, 'Acct') - student_val = Student('John') - - # Put new Strudent with StudentKey. - result = client.sql(insert_query, query_args=[student_key.ID, student_key.DEPT, student_val.NAME]) - assert next(result)[0] == 1 - - # Cache get finds the same value. - studentCache = client.get_cache('StudentCache') - val = studentCache.get(student_key) - assert val is not None - assert val.NAME == student_val.NAME - - query_result = list(client.sql(select_query, include_field_names=True)) - - validate_query_result(student_key, student_val, query_result) - - -def test_python_sql_finds_inserted_value_with_composite_key(client): - """ - Insert a record with a composite key and query it with SELECT SQL. - """ - - client.sql(drop_query) - - # Create table. - result = client.sql(create_query) - assert next(result)[0] == 0 - - student_key = StudentKey(2, 'Business') - student_val = Student('Abe') - - # Put new value using cache. - studentCache = client.get_cache('StudentCache') - studentCache.put(student_key, student_val) - - # Find the value using SQL. - query_result = list(client.sql(select_query, include_field_names=True)) - - validate_query_result(student_key, student_val, query_result) - - -def validate_query_result(student_key, student_val, query_result): - """ - Compare query result with expected key and value. - """ - assert len(query_result) == 2 - sql_row = dict(zip(query_result[0], query_result[1])) - - assert sql_row['ID'] == student_key.ID - assert sql_row['DEPT'] == student_key.DEPT - assert sql_row['NAME'] == student_val.NAME diff --git a/tests/common/test_cache_config.py b/tests/common/test_cache_config.py index b708b0c..f4c8067 100644 --- a/tests/common/test_cache_config.py +++ b/tests/common/test_cache_config.py @@ -12,29 +12,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import pytest -from pyignite.api import * -from pyignite.datatypes.prop_codes import * +from pyignite.datatypes.prop_codes import PROP_NAME, PROP_CACHE_KEY_CONFIGURATION +from pyignite.exceptions import CacheError +cache_name = 'config_cache' -def test_get_configuration(client): - conn = client.random_node - - result = cache_get_or_create(conn, 'my_unique_cache') - assert result.status == 0 - - result = cache_get_configuration(conn, 'my_unique_cache') - assert result.status == 0 - assert result.value[PROP_NAME] == 'my_unique_cache' - - -def test_create_with_config(client): - - cache_name = 'my_very_unique_name' - conn = client.random_node - - result = cache_create_with_config(conn, { +@pytest.fixture +def cache_config(): + return { PROP_NAME: cache_name, PROP_CACHE_KEY_CONFIGURATION: [ { @@ -42,38 +30,86 @@ def test_create_with_config(client): 'affinity_key_field_name': 'abc1234', } ], - }) - assert result.status == 0 + } - result = cache_get_names(conn) - assert cache_name in result.value - result = cache_create_with_config(conn, { - PROP_NAME: cache_name, - }) - assert result.status != 0 +@pytest.fixture +def cache(client): + cache = client.get_or_create_cache(cache_name) + yield cache + cache.destroy() -def test_get_or_create_with_config(client): +@pytest.fixture +async def async_cache(async_client): + cache = await async_client.get_or_create_cache(cache_name) + yield cache + await cache.destroy() - cache_name = 'my_very_unique_name' - conn = client.random_node - result = cache_get_or_create_with_config(conn, { - PROP_NAME: cache_name, - PROP_CACHE_KEY_CONFIGURATION: [ - { - 'type_name': 'blah', - 'affinity_key_field_name': 'abc1234', - } - ], - }) - assert result.status == 0 +@pytest.fixture +def cache_with_config(client, cache_config): + cache = client.get_or_create_cache(cache_config) + yield cache + cache.destroy() - result = cache_get_names(conn) - assert cache_name in result.value - result = cache_get_or_create_with_config(conn, { - PROP_NAME: cache_name, - }) - assert result.status == 0 +@pytest.fixture +async def async_cache_with_config(async_client, cache_config): + cache = await async_client.get_or_create_cache(cache_config) + yield cache + await cache.destroy() + + +def test_cache_get_configuration(client, cache): + assert cache_name in client.get_cache_names() + assert cache.settings[PROP_NAME] == cache_name + + +@pytest.mark.asyncio +async def test_cache_get_configuration_async(async_client, async_cache): + assert cache_name in (await async_client.get_cache_names()) + assert (await async_cache.settings())[PROP_NAME] == cache_name + + +def test_get_or_create_with_config_existing(client, cache_with_config, cache_config): + assert cache_name in client.get_cache_names() + + with pytest.raises(CacheError): + client.create_cache(cache_config) + + cache = client.get_or_create_cache(cache_config) + assert cache.settings == cache_with_config.settings + + +@pytest.mark.asyncio +async def test_get_or_create_with_config_existing_async(async_client, async_cache_with_config, cache_config): + assert cache_name in (await async_client.get_cache_names()) + + with pytest.raises(CacheError): + await async_client.create_cache(cache_config) + + cache = await async_client.get_or_create_cache(cache_config) + assert (await cache.settings()) == (await async_cache_with_config.settings()) + + +def test_get_or_create_with_config_new(client, cache_config): + assert cache_name not in client.get_cache_names() + cache = client.get_or_create_cache(cache_config) + try: + assert cache_name in client.get_cache_names() + assert cache.settings[PROP_NAME] == cache_name + finally: + cache.destroy() + + +@pytest.mark.asyncio +async def test_get_or_create_with_config_new_async(async_client, cache_config): + assert cache_name not in (await async_client.get_cache_names()) + + cache = await async_client.get_or_create_cache(cache_config) + try: + assert cache_name in (await async_client.get_cache_names()) + assert (await cache.settings())[PROP_NAME] == cache_name + finally: + await cache.destroy() diff --git a/tests/common/test_datatypes.py b/tests/common/test_datatypes.py index 83e9a60..c1aa19f 100644 --- a/tests/common/test_datatypes.py +++ b/tests/common/test_datatypes.py @@ -20,199 +20,239 @@ import pytest import uuid -from pyignite.api.key_value import cache_get, cache_put -from pyignite.datatypes import * +from pyignite.datatypes import ( + ByteObject, IntObject, FloatObject, CharObject, ShortObject, BoolObject, ByteArrayObject, IntArrayObject, + ShortArrayObject, FloatArrayObject, BoolArrayObject, CharArrayObject, TimestampObject, String, BinaryEnumObject, + TimestampArrayObject, BinaryEnumArrayObject, ObjectArrayObject, CollectionObject, MapObject +) from pyignite.utils import unsigned +put_get_data_params = [ + # integers + (42, None), + (42, ByteObject), + (42, ShortObject), + (42, IntObject), + + # floats + (3.1415, None), # True for Double but not Float + (3.5, FloatObject), + + # char is never autodetected + ('ы', CharObject), + ('カ', CharObject), + + # bool + (True, None), + (False, None), + (True, BoolObject), + (False, BoolObject), + + # arrays of integers + ([1, 2, 3, 5], None), + (b'buzz', ByteArrayObject), + (bytearray([7, 8, 8, 11]), None), + (bytearray([7, 8, 8, 11]), ByteArrayObject), + ([1, 2, 3, 5], ShortArrayObject), + ([1, 2, 3, 5], IntArrayObject), + + # arrays of floats + ([2.2, 4.4, 6.6], None), + ([2.5, 6.5], FloatArrayObject), + + # array of char + (['ы', 'カ'], CharArrayObject), + + # array of bool + ([True, False, True], None), + ([True, False], BoolArrayObject), + ([False, True], BoolArrayObject), + ([True, False, True, False], BoolArrayObject), + + # string + ('Little Mary had a lamb', None), + ('This is a test', String), + + # decimals + (decimal.Decimal('2.5'), None), + (decimal.Decimal('-1.3'), None), + + # uuid + (uuid.uuid4(), None), + + # date + (datetime(year=1998, month=4, day=6, hour=18, minute=30), None), + + # no autodetection for timestamp either + ( + (datetime(year=1998, month=4, day=6, hour=18, minute=30), 1000), + TimestampObject + ), + + # time + (timedelta(days=4, hours=4, minutes=24), None), + + # enum is useless in Python, except for interoperability with Java. + # Also no autodetection + ((5, 6), BinaryEnumObject), + + # arrays of standard types + (['String 1', 'String 2'], None), + (['Some of us are empty', None, 'But not the others'], None), + + ([decimal.Decimal('2.71828'), decimal.Decimal('100')], None), + ([decimal.Decimal('2.1'), None, decimal.Decimal('3.1415')], None), + + ([uuid.uuid4(), uuid.uuid4()], None), + ( + [ + datetime(year=2010, month=1, day=1), + datetime(year=2010, month=12, day=31), + ], + None, + ), + ([timedelta(minutes=30), timedelta(hours=2)], None), + ( + [ + (datetime(year=2010, month=1, day=1), 1000), + (datetime(year=2010, month=12, day=31), 200), + ], + TimestampArrayObject + ), + ((-1, [(6001, 1), (6002, 2), (6003, 3)]), BinaryEnumArrayObject), + + # object array + ((ObjectArrayObject.OBJECT, [1, 2, decimal.Decimal('3')]), ObjectArrayObject), + + # collection + ((CollectionObject.LINKED_LIST, [1, 2, 3]), None), + + # map + ((MapObject.HASH_MAP, {'key': 4, 5: 6.0}), None), + ((MapObject.LINKED_HASH_MAP, OrderedDict([('key', 4), (5, 6.0)])), None), +] + @pytest.mark.parametrize( 'value, value_hint', - [ - # integers - (42, None), - (42, ByteObject), - (42, ShortObject), - (42, IntObject), - - # floats - (3.1415, None), # True for Double but not Float - (3.5, FloatObject), - - # char is never autodetected - ('ы', CharObject), - ('カ', CharObject), - - # bool - (True, None), - (False, None), - (True, BoolObject), - (False, BoolObject), - - # arrays of integers - ([1, 2, 3, 5], None), - (b'buzz', ByteArrayObject), - (bytearray([7, 8, 8, 11]), None), - (bytearray([7, 8, 8, 11]), ByteArrayObject), - ([1, 2, 3, 5], ShortArrayObject), - ([1, 2, 3, 5], IntArrayObject), - - # arrays of floats - ([2.2, 4.4, 6.6], None), - ([2.5, 6.5], FloatArrayObject), - - # array of char - (['ы', 'カ'], CharArrayObject), - - # array of bool - ([True, False, True], None), - ([True, False], BoolArrayObject), - ([False, True], BoolArrayObject), - ([True, False, True, False], BoolArrayObject), - - # string - ('Little Mary had a lamb', None), - ('This is a test', String), - - # decimals - (decimal.Decimal('2.5'), None), - (decimal.Decimal('-1.3'), None), - - # uuid - (uuid.uuid4(), None), - - # date - (datetime(year=1998, month=4, day=6, hour=18, minute=30), None), - - # no autodetection for timestamp either - ( - (datetime(year=1998, month=4, day=6, hour=18, minute=30), 1000), - TimestampObject - ), - - # time - (timedelta(days=4, hours=4, minutes=24), None), - - # enum is useless in Python, except for interoperability with Java. - # Also no autodetection - ((5, 6), BinaryEnumObject), - - # arrays of standard types - (['String 1', 'String 2'], None), - (['Some of us are empty', None, 'But not the others'], None), - - ([decimal.Decimal('2.71828'), decimal.Decimal('100')], None), - ([decimal.Decimal('2.1'), None, decimal.Decimal('3.1415')], None), - - ([uuid.uuid4(), uuid.uuid4()], None), - ( - [ - datetime(year=2010, month=1, day=1), - datetime(year=2010, month=12, day=31), - ], - None, - ), - ([timedelta(minutes=30), timedelta(hours=2)], None), - ( - [ - (datetime(year=2010, month=1, day=1), 1000), - (datetime(year=2010, month=12, day=31), 200), - ], - TimestampArrayObject - ), - ((-1, [(6001, 1), (6002, 2), (6003, 3)]), BinaryEnumArrayObject), - - # object array - ((ObjectArrayObject.OBJECT, [1, 2, decimal.Decimal('3')]), ObjectArrayObject), - - # collection - ((CollectionObject.LINKED_LIST, [1, 2, 3]), None), - - # map - ((MapObject.HASH_MAP, {'key': 4, 5: 6.0}), None), - ((MapObject.LINKED_HASH_MAP, OrderedDict([('key', 4), (5, 6.0)])), None), - ] + put_get_data_params ) -def test_put_get_data(client, cache, value, value_hint): +def test_put_get_data(cache, value, value_hint): + cache.put('my_key', value, value_hint=value_hint) + assert cache.get('my_key') == value - conn = client.random_node - result = cache_put(conn, cache, 'my_key', value, value_hint=value_hint) - assert result.status == 0 +@pytest.mark.parametrize( + 'value, value_hint', + put_get_data_params +) +@pytest.mark.asyncio +async def test_put_get_data_async(async_cache, value, value_hint): + await async_cache.put('my_key', value, value_hint=value_hint) + assert await async_cache.get('my_key') == value - result = cache_get(conn, cache, 'my_key') - assert result.status == 0 - assert result.value == value - if isinstance(result.value, list): - for res, val in zip(result.value, value): - assert type(res) == type(val) +bytearray_params = [ + [1, 2, 3, 5], + (7, 8, 13, 18), + (-128, -1, 0, 1, 127, 255), +] @pytest.mark.parametrize( 'value', - [ - [1, 2, 3, 5], - (7, 8, 13, 18), - (-128, -1, 0, 1, 127, 255), - ] + bytearray_params ) -def test_bytearray_from_list_or_tuple(client, cache, value): +def test_bytearray_from_list_or_tuple(cache, value): """ ByteArrayObject's pythonic type is `bytearray`, but it should also accept lists or tuples as a content. """ - conn = client.random_node + cache.put('my_key', value, value_hint=ByteArrayObject) + + assert cache.get('my_key') == bytearray([unsigned(ch, ctypes.c_ubyte) for ch in value]) + + +@pytest.mark.parametrize( + 'value', + bytearray_params +) +@pytest.mark.asyncio +async def test_bytearray_from_list_or_tuple_async(async_cache, value): + """ + ByteArrayObject's pythonic type is `bytearray`, but it should also accept + lists or tuples as a content. + """ + + await async_cache.put('my_key', value, value_hint=ByteArrayObject) + + result = await async_cache.get('my_key') + assert result == bytearray([unsigned(ch, ctypes.c_ubyte) for ch in value]) - result = cache_put( - conn, cache, 'my_key', value, value_hint=ByteArrayObject - ) - assert result.status == 0 - result = cache_get(conn, cache, 'my_key') - assert result.status == 0 - assert result.value == bytearray([ - unsigned(ch, ctypes.c_ubyte) for ch in value - ]) +uuid_params = [ + 'd57babad-7bc1-4c82-9f9c-e72841b92a85', + '5946c0c0-2b76-479d-8694-a2e64a3968da', + 'a521723d-ad5d-46a6-94ad-300f850ef704', +] + +uuid_table_create_sql = "CREATE TABLE test_uuid_repr (id INTEGER PRIMARY KEY, uuid_field UUID)" +uuid_table_drop_sql = "DROP TABLE test_uuid_repr IF EXISTS" +uuid_table_insert_sql = "INSERT INTO test_uuid_repr(id, uuid_field) VALUES (?, ?)" +uuid_table_query_sql = "SELECT * FROM test_uuid_repr WHERE uuid_field=?" + + +@pytest.fixture() +async def uuid_table(client): + client.sql(uuid_table_drop_sql) + client.sql(uuid_table_create_sql) + yield None + client.sql(uuid_table_drop_sql) + + +@pytest.fixture() +async def uuid_table_async(async_client): + await async_client.sql(uuid_table_drop_sql) + await async_client.sql(uuid_table_create_sql) + yield None + await async_client.sql(uuid_table_drop_sql) @pytest.mark.parametrize( 'uuid_string', - [ - 'd57babad-7bc1-4c82-9f9c-e72841b92a85', - '5946c0c0-2b76-479d-8694-a2e64a3968da', - 'a521723d-ad5d-46a6-94ad-300f850ef704', - ] + uuid_params ) -def test_uuid_representation(client, uuid_string): +def test_uuid_representation(client, uuid_string, uuid_table): """ Test if textual UUID representation is correct. """ uuid_value = uuid.UUID(uuid_string) - # initial cleanup - client.sql("DROP TABLE test_uuid_repr IF EXISTS") - # create table with UUID field - client.sql( - "CREATE TABLE test_uuid_repr (id INTEGER PRIMARY KEY, uuid_field UUID)" - ) # use uuid.UUID class to insert data - client.sql( - "INSERT INTO test_uuid_repr(id, uuid_field) VALUES (?, ?)", - query_args=[1, uuid_value] - ) + client.sql(uuid_table_insert_sql, query_args=[1, uuid_value]) # use hex string to retrieve data - result = client.sql( - "SELECT * FROM test_uuid_repr WHERE uuid_field='{}'".format( - uuid_string - ) - ) - - # finalize query - result = list(result) - - # final cleanup - client.sql("DROP TABLE test_uuid_repr IF EXISTS") - - # if a line was retrieved, our test was successful - assert len(result) == 1 - # doublecheck - assert result[0][1] == uuid_value + with client.sql(uuid_table_query_sql, query_args=[str(uuid_value)]) as cursor: + result = list(cursor) + + # if a line was retrieved, our test was successful + assert len(result) == 1 + assert result[0][1] == uuid_value + + +@pytest.mark.parametrize( + 'uuid_string', + uuid_params +) +@pytest.mark.asyncio +async def test_uuid_representation_async(async_client, uuid_string, uuid_table_async): + """ Test if textual UUID representation is correct. """ + uuid_value = uuid.UUID(uuid_string) + + # use uuid.UUID class to insert data + await async_client.sql(uuid_table_insert_sql, query_args=[1, uuid_value]) + # use hex string to retrieve data + async with async_client.sql(uuid_table_query_sql, query_args=[str(uuid_value)]) as cursor: + result = [row async for row in cursor] + + # if a line was retrieved, our test was successful + assert len(result) == 1 + assert result[0][1] == uuid_value diff --git a/tests/common/test_generic_object.py b/tests/common/test_generic_object.py index 73dc870..d6c0ee1 100644 --- a/tests/common/test_generic_object.py +++ b/tests/common/test_generic_object.py @@ -14,11 +14,10 @@ # limitations under the License. from pyignite import GenericObjectMeta -from pyignite.datatypes import * +from pyignite.datatypes import IntObject, String def test_go(): - class GenericObject( metaclass=GenericObjectMeta, schema={ diff --git a/tests/common/test_get_names.py b/tests/common/test_get_names.py index 2d6c0bc..7fcb499 100644 --- a/tests/common/test_get_names.py +++ b/tests/common/test_get_names.py @@ -12,21 +12,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import asyncio -from pyignite.api import cache_create, cache_get_names +import pytest def test_get_names(client): + bucket_names = {'my_bucket', 'my_bucket_2', 'my_bucket_3'} + for name in bucket_names: + client.get_or_create_cache(name) - conn = client.random_node + assert set(client.get_cache_names()) == bucket_names - bucket_names = ['my_bucket', 'my_bucket_2', 'my_bucket_3'] - for name in bucket_names: - cache_create(conn, name) - result = cache_get_names(conn) - assert result.status == 0 - assert type(result.value) == list - assert len(result.value) >= len(bucket_names) - for i, name in enumerate(bucket_names): - assert name in result.value +@pytest.mark.asyncio +async def test_get_names_async(async_client): + bucket_names = {'my_bucket', 'my_bucket_2', 'my_bucket_3'} + await asyncio.gather(*[async_client.get_or_create_cache(name) for name in bucket_names]) + + assert set(await async_client.get_cache_names()) == bucket_names diff --git a/tests/common/test_key_value.py b/tests/common/test_key_value.py index a7edce1..0f492a2 100644 --- a/tests/common/test_key_value.py +++ b/tests/common/test_key_value.py @@ -15,426 +15,405 @@ from datetime import datetime -from pyignite.api import * -from pyignite.datatypes import ( - CollectionObject, IntObject, MapObject, TimestampObject, -) +import pytest +from pyignite.datatypes import CollectionObject, IntObject, MapObject, TimestampObject -def test_put_get(client, cache): - conn = client.random_node +def test_put_get(cache): + cache.put('my_key', 5) - result = cache_put(conn, cache, 'my_key', 5) - assert result.status == 0 + assert cache.get('my_key') == 5 - result = cache_get(conn, cache, 'my_key') - assert result.status == 0 - assert result.value == 5 +@pytest.mark.asyncio +async def test_put_get_async(async_cache): + await async_cache.put('my_key', 5) -def test_get_all(client, cache): + assert await async_cache.get('my_key') == 5 - conn = client.random_node - result = cache_get_all(conn, cache, ['key_1', 2, (3, IntObject)]) - assert result.status == 0 - assert result.value == {} +def test_get_all(cache): + assert cache.get_all(['key_1', 2, (3, IntObject)]) == {} - cache_put(conn, cache, 'key_1', 4) - cache_put(conn, cache, 3, 18, key_hint=IntObject) + cache.put('key_1', 4) + cache.put(3, 18, key_hint=IntObject) - result = cache_get_all(conn, cache, ['key_1', 2, (3, IntObject)]) - assert result.status == 0 - assert result.value == {'key_1': 4, 3: 18} + assert cache.get_all(['key_1', 2, (3, IntObject)]) == {'key_1': 4, 3: 18} -def test_put_all(client, cache): +@pytest.mark.asyncio +async def test_get_all_async(async_cache): + assert await async_cache.get_all(['key_1', 2, (3, IntObject)]) == {} - conn = client.random_node + await async_cache.put('key_1', 4) + await async_cache.put(3, 18, key_hint=IntObject) + assert await async_cache.get_all(['key_1', 2, (3, IntObject)]) == {'key_1': 4, 3: 18} + + +def test_put_all(cache): test_dict = { 1: 2, 'key_1': 4, (3, IntObject): 18, } - test_keys = ['key_1', 1, 3] - - result = cache_put_all(conn, cache, test_dict) - assert result.status == 0 - - result = cache_get_all(conn, cache, test_keys) - assert result.status == 0 - assert len(test_dict) == 3 - - for key in result.value: - assert key in test_keys - - -def test_contains_key(client, cache): - - conn = client.random_node - - cache_put(conn, cache, 'test_key', 42) - - result = cache_contains_key(conn, cache, 'test_key') - assert result.value is True - - result = cache_contains_key(conn, cache, 'non-existant-key') - assert result.value is False - - -def test_contains_keys(client, cache): - - conn = client.random_node - - cache_put(conn, cache, 5, 6) - cache_put(conn, cache, 'test_key', 42) + cache.put_all(test_dict) - result = cache_contains_keys(conn, cache, [5, 'test_key']) - assert result.value is True + result = cache.get_all(list(test_dict.keys())) - result = cache_contains_keys(conn, cache, [5, 'non-existent-key']) - assert result.value is False + assert len(result) == len(test_dict) + for k, v in test_dict.items(): + k = k[0] if isinstance(k, tuple) else k + assert result[k] == v -def test_get_and_put(client, cache): - - conn = client.random_node - - result = cache_get_and_put(conn, cache, 'test_key', 42) - assert result.status == 0 - assert result.value is None - - result = cache_get(conn, cache, 'test_key') - assert result.status == 0 - assert result.value is 42 - - result = cache_get_and_put(conn, cache, 'test_key', 1234) - assert result.status == 0 - assert result.value == 42 - - -def test_get_and_replace(client, cache): - - conn = client.random_node - - result = cache_get_and_replace(conn, cache, 'test_key', 42) - assert result.status == 0 - assert result.value is None - - result = cache_get(conn, cache, 'test_key') - assert result.status == 0 - assert result.value is None - - cache_put(conn, cache, 'test_key', 42) - - result = cache_get_and_replace(conn, cache, 'test_key', 1234) - assert result.status == 0 - assert result.value == 42 - +@pytest.mark.asyncio +async def test_put_all_async(async_cache): + test_dict = { + 1: 2, + 'key_1': 4, + (3, IntObject): 18, + } + await async_cache.put_all(test_dict) -def test_get_and_remove(client, cache): + result = await async_cache.get_all(list(test_dict.keys())) - conn = client.random_node + assert len(result) == len(test_dict) + for k, v in test_dict.items(): + k = k[0] if isinstance(k, tuple) else k + assert result[k] == v - result = cache_get_and_remove(conn, cache, 'test_key') - assert result.status == 0 - assert result.value is None - cache_put(conn, cache, 'test_key', 42) +def test_contains_key(cache): + cache.put('test_key', 42) - result = cache_get_and_remove(conn, cache, 'test_key') - assert result.status == 0 - assert result.value == 42 + assert cache.contains_key('test_key') + assert not cache.contains_key('non-existent-key') -def test_put_if_absent(client, cache): +@pytest.mark.asyncio +async def test_contains_key_async(async_cache): + await async_cache.put('test_key', 42) - conn = client.random_node + assert await async_cache.contains_key('test_key') + assert not await async_cache.contains_key('non-existent-key') - result = cache_put_if_absent(conn, cache, 'test_key', 42) - assert result.status == 0 - assert result.value is True - result = cache_put_if_absent(conn, cache, 'test_key', 1234) - assert result.status == 0 - assert result.value is False +def test_contains_keys(cache): + cache.put(5, 6) + cache.put('test_key', 42) + assert cache.contains_keys([5, 'test_key']) + assert not cache.contains_keys([5, 'non-existent-key']) -def test_get_and_put_if_absent(client, cache): - conn = client.random_node +@pytest.mark.asyncio +async def test_contains_keys_async(async_cache): + await async_cache.put(5, 6) + await async_cache.put('test_key', 42) - result = cache_get_and_put_if_absent(conn, cache, 'test_key', 42) - assert result.status == 0 - assert result.value is None + assert await async_cache.contains_keys([5, 'test_key']) + assert not await async_cache.contains_keys([5, 'non-existent-key']) - result = cache_get_and_put_if_absent(conn, cache, 'test_key', 1234) - assert result.status == 0 - assert result.value == 42 - result = cache_get_and_put_if_absent(conn, cache, 'test_key', 5678) - assert result.status == 0 - assert result.value == 42 +def test_get_and_put(cache): + assert cache.get_and_put('test_key', 42) is None + assert cache.get('test_key') == 42 + assert cache.get_and_put('test_key', 1234) == 42 + assert cache.get('test_key') == 1234 -def test_replace(client, cache): +@pytest.mark.asyncio +async def test_get_and_put_async(async_cache): + assert await async_cache.get_and_put('test_key', 42) is None + assert await async_cache.get('test_key') == 42 + assert await async_cache.get_and_put('test_key', 1234) == 42 + assert await async_cache.get('test_key') == 1234 - conn = client.random_node - result = cache_replace(conn, cache, 'test_key', 42) - assert result.status == 0 - assert result.value is False +def test_get_and_replace(cache): + assert cache.get_and_replace('test_key', 42) is None + assert cache.get('test_key') is None + cache.put('test_key', 42) + assert cache.get_and_replace('test_key', 1234) == 42 - cache_put(conn, cache, 'test_key', 1234) - result = cache_replace(conn, cache, 'test_key', 42) - assert result.status == 0 - assert result.value is True +@pytest.mark.asyncio +async def test_get_and_replace_async(async_cache): + assert await async_cache.get_and_replace('test_key', 42) is None + assert await async_cache.get('test_key') is None + await async_cache.put('test_key', 42) + assert await async_cache.get_and_replace('test_key', 1234) == 42 - result = cache_get(conn, cache, 'test_key') - assert result.status == 0 - assert result.value == 42 +def test_get_and_remove(cache): + assert cache.get_and_remove('test_key') is None + cache.put('test_key', 42) + assert cache.get_and_remove('test_key') == 42 + assert cache.get_and_remove('test_key') is None -def test_replace_if_equals(client, cache): - conn = client.random_node +@pytest.mark.asyncio +async def test_get_and_remove_async(async_cache): + assert await async_cache.get_and_remove('test_key') is None + await async_cache.put('test_key', 42) + assert await async_cache.get_and_remove('test_key') == 42 + assert await async_cache.get_and_remove('test_key') is None - result = cache_replace_if_equals(conn, cache, 'my_test', 42, 1234) - assert result.status == 0 - assert result.value is False - cache_put(conn, cache, 'my_test', 42) +def test_put_if_absent(cache): + assert cache.put_if_absent('test_key', 42) + assert not cache.put_if_absent('test_key', 1234) - result = cache_replace_if_equals(conn, cache, 'my_test', 42, 1234) - assert result.status == 0 - assert result.value is True - result = cache_get(conn, cache, 'my_test') - assert result.status == 0 - assert result.value == 1234 +@pytest.mark.asyncio +async def test_put_if_absent_async(async_cache): + assert await async_cache.put_if_absent('test_key', 42) + assert not await async_cache.put_if_absent('test_key', 1234) -def test_clear(client, cache): +def test_get_and_put_if_absent(cache): + assert cache.get_and_put_if_absent('test_key', 42) is None + assert cache.get_and_put_if_absent('test_key', 1234) == 42 + assert cache.get_and_put_if_absent('test_key', 5678) == 42 + assert cache.get('test_key') == 42 - conn = client.random_node - result = cache_put(conn, cache, 'my_test', 42) - assert result.status == 0 +@pytest.mark.asyncio +async def test_get_and_put_if_absent_async(async_cache): + assert await async_cache.get_and_put_if_absent('test_key', 42) is None + assert await async_cache.get_and_put_if_absent('test_key', 1234) == 42 + assert await async_cache.get_and_put_if_absent('test_key', 5678) == 42 + assert await async_cache.get('test_key') == 42 - result = cache_clear(conn, cache) - assert result.status == 0 - result = cache_get(conn, cache, 'my_test') - assert result.status == 0 - assert result.value is None +def test_replace(cache): + assert cache.replace('test_key', 42) is False + cache.put('test_key', 1234) + assert cache.replace('test_key', 42) is True + assert cache.get('test_key') == 42 -def test_clear_key(client, cache): +@pytest.mark.asyncio +async def test_replace_async(async_cache): + assert await async_cache.replace('test_key', 42) is False + await async_cache.put('test_key', 1234) + assert await async_cache.replace('test_key', 42) is True + assert await async_cache.get('test_key') == 42 - conn = client.random_node - result = cache_put(conn, cache, 'my_test', 42) - assert result.status == 0 +def test_replace_if_equals(cache): + assert cache.replace_if_equals('my_test', 42, 1234) is False + cache.put('my_test', 42) + assert cache.replace_if_equals('my_test', 42, 1234) is True + assert cache.get('my_test') == 1234 - result = cache_put(conn, cache, 'another_test', 24) - assert result.status == 0 - result = cache_clear_key(conn, cache, 'my_test') - assert result.status == 0 +@pytest.mark.asyncio +async def test_replace_if_equals_async(async_cache): + assert await async_cache.replace_if_equals('my_test', 42, 1234) is False + await async_cache.put('my_test', 42) + assert await async_cache.replace_if_equals('my_test', 42, 1234) is True + assert await async_cache.get('my_test') == 1234 - result = cache_get(conn, cache, 'my_test') - assert result.status == 0 - assert result.value is None - result = cache_get(conn, cache, 'another_test') - assert result.status == 0 - assert result.value == 24 +def test_clear(cache): + cache.put('my_test', 42) + cache.clear() + assert cache.get('my_test') is None -def test_clear_keys(client, cache): +@pytest.mark.asyncio +async def test_clear_async(async_cache): + await async_cache.put('my_test', 42) + await async_cache.clear() + assert await async_cache.get('my_test') is None - conn = client.random_node - result = cache_put(conn, cache, 'my_test_key', 42) - assert result.status == 0 +def test_clear_key(cache): + cache.put('my_test', 42) + cache.put('another_test', 24) - result = cache_put(conn, cache, 'another_test', 24) - assert result.status == 0 + cache.clear_key('my_test') - result = cache_clear_keys(conn, cache, [ - 'my_test_key', - 'nonexistent_key', - ]) - assert result.status == 0 + assert cache.get('my_test') is None + assert cache.get('another_test') == 24 - result = cache_get(conn, cache, 'my_test_key') - assert result.status == 0 - assert result.value is None - result = cache_get(conn, cache, 'another_test') - assert result.status == 0 - assert result.value == 24 +@pytest.mark.asyncio +async def test_clear_key_async(async_cache): + await async_cache.put('my_test', 42) + await async_cache.put('another_test', 24) + await async_cache.clear_key('my_test') -def test_remove_key(client, cache): + assert await async_cache.get('my_test') is None + assert await async_cache.get('another_test') == 24 - conn = client.random_node - result = cache_put(conn, cache, 'my_test_key', 42) - assert result.status == 0 +def test_clear_keys(cache): + cache.put('my_test_key', 42) + cache.put('another_test', 24) - result = cache_remove_key(conn, cache, 'my_test_key') - assert result.status == 0 - assert result.value is True + cache.clear_keys(['my_test_key', 'nonexistent_key']) - result = cache_remove_key(conn, cache, 'non_existent_key') - assert result.status == 0 - assert result.value is False + assert cache.get('my_test_key') is None + assert cache.get('another_test') == 24 -def test_remove_if_equals(client, cache): +@pytest.mark.asyncio +async def test_clear_keys_async(async_cache): + await async_cache.put('my_test_key', 42) + await async_cache.put('another_test', 24) - conn = client.random_node + await async_cache.clear_keys(['my_test_key', 'nonexistent_key']) - result = cache_put(conn, cache, 'my_test', 42) - assert result.status == 0 + assert await async_cache.get('my_test_key') is None + assert await async_cache.get('another_test') == 24 - result = cache_remove_if_equals(conn, cache, 'my_test', 1234) - assert result.status == 0 - assert result.value is False - result = cache_remove_if_equals(conn, cache, 'my_test', 42) - assert result.status == 0 - assert result.value is True +def test_remove_key(cache): + cache.put('my_test_key', 42) + assert cache.remove_key('my_test_key') is True + assert cache.remove_key('non_existent_key') is False - result = cache_get(conn, cache, 'my_test') - assert result.status == 0 - assert result.value is None +@pytest.mark.asyncio +async def test_remove_key_async(async_cache): + await async_cache.put('my_test_key', 42) + assert await async_cache.remove_key('my_test_key') is True + assert await async_cache.remove_key('non_existent_key') is False -def test_remove_keys(client, cache): - conn = client.random_node +def test_remove_if_equals(cache): + cache.put('my_test', 42) + assert cache.remove_if_equals('my_test', 1234) is False + assert cache.remove_if_equals('my_test', 42) is True + assert cache.get('my_test') is None - result = cache_put(conn, cache, 'my_test', 42) - assert result.status == 0 - result = cache_put(conn, cache, 'another_test', 24) - assert result.status == 0 +@pytest.mark.asyncio +async def test_remove_if_equals_async(async_cache): + await async_cache.put('my_test', 42) + assert await async_cache.remove_if_equals('my_test', 1234) is False + assert await async_cache.remove_if_equals('my_test', 42) is True + assert await async_cache.get('my_test') is None - result = cache_remove_keys(conn, cache, ['my_test', 'non_existent']) - assert result.status == 0 - result = cache_get(conn, cache, 'my_test') - assert result.status == 0 - assert result.value is None +def test_remove_keys(cache): + cache.put('my_test', 42) - result = cache_get(conn, cache, 'another_test') - assert result.status == 0 - assert result.value == 24 + cache.put('another_test', 24) + cache.remove_keys(['my_test', 'non_existent']) + assert cache.get('my_test') is None + assert cache.get('another_test') == 24 -def test_remove_all(client, cache): - conn = client.random_node +@pytest.mark.asyncio +async def test_remove_keys_async(async_cache): + await async_cache.put('my_test', 42) - result = cache_put(conn, cache, 'my_test', 42) - assert result.status == 0 + await async_cache.put('another_test', 24) + await async_cache.remove_keys(['my_test', 'non_existent']) - result = cache_put(conn, cache, 'another_test', 24) - assert result.status == 0 + assert await async_cache.get('my_test') is None + assert await async_cache.get('another_test') == 24 - result = cache_remove_all(conn, cache) - assert result.status == 0 - result = cache_get(conn, cache, 'my_test') - assert result.status == 0 - assert result.value is None +def test_remove_all(cache): + cache.put('my_test', 42) + cache.put('another_test', 24) + cache.remove_all() - result = cache_get(conn, cache, 'another_test') - assert result.status == 0 - assert result.value is None + assert cache.get('my_test') is None + assert cache.get('another_test') is None -def test_cache_get_size(client, cache): +@pytest.mark.asyncio +async def test_remove_all_async(async_cache): + await async_cache.put('my_test', 42) + await async_cache.put('another_test', 24) + await async_cache.remove_all() - conn = client.random_node + assert await async_cache.get('my_test') is None + assert await async_cache.get('another_test') is None - result = cache_put(conn, cache, 'my_test', 42) - assert result.status == 0 - result = cache_get_size(conn, cache) - assert result.status == 0 - assert result.value == 1 +def test_cache_get_size(cache): + cache.put('my_test', 42) + assert cache.get_size() == 1 -def test_put_get_collection(client): +@pytest.mark.asyncio +async def test_cache_get_size_async(async_cache): + await async_cache.put('my_test', 42) + assert await async_cache.get_size() == 1 - test_datetime = datetime(year=1996, month=3, day=1) - cache = client.get_or_create_cache('test_coll_cache') - cache.put( +collection_params = [ + [ 'simple', - ( - 1, - [ - (123, IntObject), - 678, - None, - 55.2, - ((test_datetime, 0), TimestampObject), - ] - ), - value_hint=CollectionObject - ) - value = cache.get('simple') - assert value == (1, [123, 678, None, 55.2, (test_datetime, 0)]) - - cache.put( + (1, [(123, IntObject), 678, None, 55.2, ((datetime(year=1996, month=3, day=1), 0), TimestampObject)]), + (1, [123, 678, None, 55.2, (datetime(year=1996, month=3, day=1), 0)]) + ], + [ 'nested', - ( - 1, - [ - 123, - ((1, [456, 'inner_test_string', 789]), CollectionObject), - 'outer_test_string', - ] - ), - value_hint=CollectionObject - ) - value = cache.get('nested') - assert value == ( - 1, - [ - 123, - (1, [456, 'inner_test_string', 789]), - 'outer_test_string' - ] - ) - - -def test_put_get_map(client): - - cache = client.get_or_create_cache('test_map_cache') - - cache.put( - 'test_map', + (1, [123, ((1, [456, 'inner_test_string', 789]), CollectionObject), 'outer_test_string']), + (1, [123, (1, [456, 'inner_test_string', 789]), 'outer_test_string']) + ], + [ + 'hash_map', ( MapObject.HASH_MAP, { (123, IntObject): 'test_data', 456: ((1, [456, 'inner_test_string', 789]), CollectionObject), 'test_key': 32.4, + 'simple_strings': ['string_1', 'string_2'] + } + ), + ( + MapObject.HASH_MAP, + { + 123: 'test_data', + 456: (1, [456, 'inner_test_string', 789]), + 'test_key': 32.4, + 'simple_strings': ['string_1', 'string_2'] + } + ) + ], + [ + 'linked_hash_map', + ( + MapObject.LINKED_HASH_MAP, + { + 'test_data': 12345, + 456: ['string_1', 'string_2'], + 'test_key': 32.4 } ), - value_hint=MapObject - ) - value = cache.get('test_map') - assert value == (MapObject.HASH_MAP, { - 123: 'test_data', - 456: (1, [456, 'inner_test_string', 789]), - 'test_key': 32.4, - }) + ( + MapObject.LINKED_HASH_MAP, + { + 'test_data': 12345, + 456: ['string_1', 'string_2'], + 'test_key': 32.4 + } + ) + ], +] + + +@pytest.mark.parametrize(['key', 'hinted_value', 'value'], collection_params) +def test_put_get_collection(cache, key, hinted_value, value): + cache.put(key, hinted_value) + assert cache.get(key) == value + + +@pytest.mark.parametrize(['key', 'hinted_value', 'value'], collection_params) +@pytest.mark.asyncio +async def test_put_get_collection_async(async_cache, key, hinted_value, value): + await async_cache.put(key, hinted_value) + assert await async_cache.get(key) == value diff --git a/tests/common/test_scan.py b/tests/common/test_scan.py index 2f0e056..d55fd3e 100644 --- a/tests/common/test_scan.py +++ b/tests/common/test_scan.py @@ -12,57 +12,153 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from collections import OrderedDict -from pyignite.api import ( - scan, scan_cursor_get_page, resource_close, cache_put_all, -) +import pytest +from pyignite import GenericObjectMeta +from pyignite.api import resource_close, resource_close_async +from pyignite.connection import AioConnection +from pyignite.datatypes import IntObject, String +from pyignite.exceptions import CacheError -def test_scan(client, cache): - conn = client.random_node - page_size = 10 +class SimpleObject( + metaclass=GenericObjectMeta, + type_name='SimpleObject', + schema=OrderedDict([ + ('id', IntObject), + ('str', String), + ]) +): + pass - result = cache_put_all(conn, cache, { - 'key_{}'.format(v): v for v in range(page_size * 2) - }) - assert result.status == 0 - result = scan(conn, cache, page_size) - assert result.status == 0 - assert len(result.value['data']) == page_size - assert result.value['more'] is True +page_size = 10 - cursor = result.value['cursor'] - result = scan_cursor_get_page(conn, cursor) - assert result.status == 0 - assert len(result.value['data']) == page_size - assert result.value['more'] is False +@pytest.fixture +def test_objects_data(): + yield {i: SimpleObject(id=i, str=f'str_{i}') for i in range(page_size * 2)} - result = scan_cursor_get_page(conn, cursor) - assert result.status != 0 +@pytest.mark.asyncio +def test_scan_objects(cache, test_objects_data): + cache.put_all(test_objects_data) -def test_close_resource(client, cache): + for p_sz in [page_size, page_size * 2, page_size * 3, page_size + 5]: + with cache.scan(p_sz) as cursor: + result = {k: v for k, v in cursor} + assert result == test_objects_data - conn = client.random_node - page_size = 10 + __check_cursor_closed(cursor) - result = cache_put_all(conn, cache, { - 'key_{}'.format(v): v for v in range(page_size * 2) - }) - assert result.status == 0 + with pytest.raises(Exception): + with cache.scan(p_sz) as cursor: + for _ in cursor: + raise Exception - result = scan(conn, cache, page_size) - assert result.status == 0 - assert len(result.value['data']) == page_size - assert result.value['more'] is True + __check_cursor_closed(cursor) - cursor = result.value['cursor'] + cursor = cache.scan(page_size) + assert {k: v for k, v in cursor} == test_objects_data + __check_cursor_closed(cursor) - result = resource_close(conn, cursor) - assert result.status == 0 - result = scan_cursor_get_page(conn, cursor) - assert result.status != 0 +@pytest.mark.asyncio +async def test_scan_objects_async(async_cache, test_objects_data): + await async_cache.put_all(test_objects_data) + + for p_sz in [page_size, page_size * 2, page_size * 3, page_size + 5]: + async with async_cache.scan(p_sz) as cursor: + result = {k: v async for k, v in cursor} + assert result == test_objects_data + + await __check_cursor_closed(cursor) + + with pytest.raises(Exception): + async with async_cache.scan(p_sz) as cursor: + async for _ in cursor: + raise Exception + + await __check_cursor_closed(cursor) + + cursor = await async_cache.scan(page_size) + assert {k: v async for k, v in cursor} == test_objects_data + + await __check_cursor_closed(cursor) + + +@pytest.fixture +def cache_scan_data(): + yield { + 1: 'This is a test', + 2: 'One more test', + 3: 'Foo', + 4: 'Buzz', + 5: 'Bar', + 6: 'Lorem ipsum', + 7: 'dolor sit amet', + 8: 'consectetur adipiscing elit', + 9: 'Nullam aliquet', + 10: 'nisl at ante', + 11: 'suscipit', + 12: 'ut cursus', + 13: 'metus interdum', + 14: 'Nulla tincidunt', + 15: 'sollicitudin iaculis', + } + + +@pytest.mark.parametrize('page_size', range(1, 17, 5)) +def test_cache_scan(cache, cache_scan_data, page_size): + cache.put_all(cache_scan_data) + + with cache.scan(page_size=page_size) as cursor: + assert {k: v for k, v in cursor} == cache_scan_data + + +@pytest.mark.parametrize('page_size', range(1, 17, 5)) +@pytest.mark.asyncio +async def test_cache_scan_async(async_cache, cache_scan_data, page_size): + await async_cache.put_all(cache_scan_data) + + async with async_cache.scan(page_size=page_size) as cursor: + assert {k: v async for k, v in cursor} == cache_scan_data + + +def test_uninitialized_cursor(cache, test_objects_data): + cache.put_all(test_objects_data) + + cursor = cache.scan(page_size) + for _ in cursor: + break + + cursor.close() + __check_cursor_closed(cursor) + + +@pytest.mark.asyncio +async def test_uninitialized_cursor_async(async_cache, test_objects_data): + await async_cache.put_all(test_objects_data) + + # iterating of non-awaited cursor. + with pytest.raises(CacheError): + cursor = async_cache.scan(page_size) + assert {k: v async for k, v in cursor} == test_objects_data + + cursor = await async_cache.scan(page_size) + assert {k: v async for k, v in cursor} == test_objects_data + await __check_cursor_closed(cursor) + + +def __check_cursor_closed(cursor): + async def check_async(): + result = await resource_close_async(cursor.connection, cursor.cursor_id) + assert result.status != 0 + + def check(): + result = resource_close(cursor.connection, cursor.cursor_id) + assert result.status != 0 + + return check_async() if isinstance(cursor.connection, AioConnection) else check() diff --git a/tests/common/test_sql.py b/tests/common/test_sql.py index cc68a02..0841b7f 100644 --- a/tests/common/test_sql.py +++ b/tests/common/test_sql.py @@ -15,160 +15,173 @@ import pytest -from pyignite.api import ( - sql_fields, sql_fields_cursor_get_page, - sql, sql_cursor_get_page, - cache_get_configuration, -) +from pyignite import AioClient +from pyignite.aio_cache import AioCache from pyignite.datatypes.cache_config import CacheMode -from pyignite.datatypes.prop_codes import * +from pyignite.datatypes.prop_codes import PROP_NAME, PROP_SQL_SCHEMA, PROP_QUERY_ENTITIES, PROP_CACHE_MODE from pyignite.exceptions import SQLError from pyignite.utils import entity_id -from pyignite.binary import unwrap_binary - -initial_data = [ - ('John', 'Doe', 5), - ('Jane', 'Roe', 4), - ('Joe', 'Bloggs', 4), - ('Richard', 'Public', 3), - ('Negidius', 'Numerius', 3), - ] -create_query = '''CREATE TABLE Student ( - id INT(11) PRIMARY KEY, - first_name CHAR(24), - last_name CHAR(32), - grade INT(11))''' - -insert_query = '''INSERT INTO Student(id, first_name, last_name, grade) -VALUES (?, ?, ?, ?)''' - -select_query = 'SELECT id, first_name, last_name, grade FROM Student' - -drop_query = 'DROP TABLE Student IF EXISTS' - -page_size = 4 - - -def test_sql(client): - - conn = client.random_node - - # cleanup - client.sql(drop_query) - - result = sql_fields( - conn, - 0, - create_query, - page_size, - schema='PUBLIC', - include_field_names=True - ) - assert result.status == 0, result.message - - for i, data_line in enumerate(initial_data, start=1): - fname, lname, grade = data_line - result = sql_fields( - conn, - 0, - insert_query, - page_size, - schema='PUBLIC', - query_args=[i, fname, lname, grade], - include_field_names=True - ) - assert result.status == 0, result.message - - result = cache_get_configuration(conn, 'SQL_PUBLIC_STUDENT') - assert result.status == 0, result.message - - binary_type_name = result.value[PROP_QUERY_ENTITIES][0]['value_type_name'] - result = sql( - conn, - 'SQL_PUBLIC_STUDENT', - binary_type_name, - 'TRUE', - page_size - ) - assert result.status == 0, result.message - assert len(result.value['data']) == page_size - assert result.value['more'] is True - - for wrapped_object in result.value['data'].values(): - data = unwrap_binary(client, wrapped_object) - assert data.type_id == entity_id(binary_type_name) - - cursor = result.value['cursor'] - - while result.value['more']: - result = sql_cursor_get_page(conn, cursor) - assert result.status == 0, result.message - - for wrapped_object in result.value['data'].values(): - data = unwrap_binary(client, wrapped_object) - assert data.type_id == entity_id(binary_type_name) - - # repeat cleanup - result = sql_fields(conn, 0, drop_query, page_size, schema='PUBLIC') - assert result.status == 0 - - -def test_sql_fields(client): - - conn = client.random_node - - # cleanup - client.sql(drop_query) - - result = sql_fields( - conn, - 0, - create_query, - page_size, - schema='PUBLIC', - include_field_names=True - ) - assert result.status == 0, result.message - - for i, data_line in enumerate(initial_data, start=1): - fname, lname, grade = data_line - result = sql_fields( - conn, - 0, - insert_query, - page_size, - schema='PUBLIC', - query_args=[i, fname, lname, grade], - include_field_names=True - ) - assert result.status == 0, result.message - - result = sql_fields( - conn, - 0, - select_query, - page_size, - schema='PUBLIC', - include_field_names=True - ) - assert result.status == 0 - assert len(result.value['data']) == page_size - assert result.value['more'] is True - - cursor = result.value['cursor'] - - result = sql_fields_cursor_get_page(conn, cursor, field_count=4) - assert result.status == 0 - assert len(result.value['data']) == len(initial_data) - page_size - assert result.value['more'] is False - - # repeat cleanup - result = sql_fields(conn, 0, drop_query, page_size, schema='PUBLIC') - assert result.status == 0 - - -def test_long_multipage_query(client): +student_table_data = [ + ('John', 'Doe', 5), + ('Jane', 'Roe', 4), + ('Joe', 'Bloggs', 4), + ('Richard', 'Public', 3), + ('Negidius', 'Numerius', 3), +] + +student_table_select_query = 'SELECT id, first_name, last_name, grade FROM Student ORDER BY ID ASC' + + +@pytest.fixture +def student_table_fixture(client): + yield from __create_student_table_fixture(client) + + +@pytest.fixture +async def async_student_table_fixture(async_client): + async for _ in __create_student_table_fixture(async_client): + yield + + +def __create_student_table_fixture(client): + create_query = '''CREATE TABLE Student ( + id INT(11) PRIMARY KEY, + first_name CHAR(24), + last_name CHAR(32), + grade INT(11))''' + + insert_query = '''INSERT INTO Student(id, first_name, last_name, grade) + VALUES (?, ?, ?, ?)''' + + drop_query = 'DROP TABLE Student IF EXISTS' + + def inner(): + client.sql(drop_query) + client.sql(create_query) + + for i, data_line in enumerate(student_table_data): + fname, lname, grade = data_line + client.sql(insert_query, query_args=[i, fname, lname, grade]) + + yield None + client.sql(drop_query) + + async def inner_async(): + await client.sql(drop_query) + await client.sql(create_query) + + for i, data_line in enumerate(student_table_data): + fname, lname, grade = data_line + await client.sql(insert_query, query_args=[i, fname, lname, grade]) + + yield None + await client.sql(drop_query) + + return inner_async() if isinstance(client, AioClient) else inner() + + +@pytest.mark.parametrize('page_size', range(1, 6, 2)) +def test_sql(client, student_table_fixture, page_size): + cache = client.get_cache('SQL_PUBLIC_STUDENT') + cache_config = cache.settings + + binary_type_name = cache_config[PROP_QUERY_ENTITIES][0]['value_type_name'] + + with cache.select_row('ORDER BY ID ASC', page_size=4) as cursor: + for i, row in enumerate(cursor): + k, v = row + assert k == i + + assert (v.FIRST_NAME, v.LAST_NAME, v.GRADE) == student_table_data[i] + assert v.type_id == entity_id(binary_type_name) + + +@pytest.mark.parametrize('page_size', range(1, 6, 2)) +def test_sql_fields(client, student_table_fixture, page_size): + with client.sql(student_table_select_query, page_size=page_size, include_field_names=True) as cursor: + for i, row in enumerate(cursor): + if i > 0: + assert tuple(row) == (i - 1,) + student_table_data[i - 1] + else: + assert row == ['ID', 'FIRST_NAME', 'LAST_NAME', 'GRADE'] + + +@pytest.mark.asyncio +@pytest.mark.parametrize('page_size', range(1, 6, 2)) +async def test_sql_fields_async(async_client, async_student_table_fixture, page_size): + async with async_client.sql(student_table_select_query, page_size=page_size, include_field_names=True) as cursor: + i = 0 + async for row in cursor: + if i > 0: + assert tuple(row) == (i - 1,) + student_table_data[i - 1] + else: + assert row == ['ID', 'FIRST_NAME', 'LAST_NAME', 'GRADE'] + i += 1 + + cursor = await async_client.sql(student_table_select_query, page_size=page_size, include_field_names=True) + try: + i = 0 + async for row in cursor: + if i > 0: + assert tuple(row) == (i - 1,) + student_table_data[i - 1] + else: + assert row == ['ID', 'FIRST_NAME', 'LAST_NAME', 'GRADE'] + i += 1 + finally: + await cursor.close() + + +multipage_fields = ["id", "abc", "ghi", "def", "jkl", "prs", "mno", "tuw", "zyz", "abc1", "def1", "jkl1", "prs1"] + + +@pytest.fixture +def long_multipage_table_fixture(client): + yield from __long_multipage_table_fixture(client) + + +@pytest.fixture +async def async_long_multipage_table_fixture(async_client): + async for _ in __long_multipage_table_fixture(async_client): + yield + + +def __long_multipage_table_fixture(client): + drop_query = 'DROP TABLE LongMultipageQuery IF EXISTS' + + create_query = "CREATE TABLE LongMultiPageQuery (%s, %s)" % ( + multipage_fields[0] + " INT(11) PRIMARY KEY", ",".join(map(lambda f: f + " INT(11)", multipage_fields[1:]))) + + insert_query = "INSERT INTO LongMultipageQuery (%s) VALUES (%s)" % ( + ",".join(multipage_fields), ",".join("?" * len(multipage_fields))) + + def query_args(_id): + return [_id] + list(i * _id for i in range(1, len(multipage_fields))) + + def inner(): + client.sql(drop_query) + client.sql(create_query) + + for i in range(1, 21): + client.sql(insert_query, query_args=query_args(i)) + yield None + + client.sql(drop_query) + + async def inner_async(): + await client.sql(drop_query) + await client.sql(create_query) + + for i in range(1, 21): + await client.sql(insert_query, query_args=query_args(i)) + yield None + + await client.sql(drop_query) + + return inner_async() if isinstance(client, AioClient) else inner() + + +def test_long_multipage_query(client, long_multipage_table_fixture): """ The test creates a table with 13 columns (id and 12 enumerated columns) and 20 records with id in range from 1 to 20. Values of enumerated columns @@ -177,25 +190,20 @@ def test_long_multipage_query(client): The goal is to ensure that all the values are selected in a right order. """ - fields = ["id", "abc", "ghi", "def", "jkl", "prs", "mno", "tuw", "zyz", "abc1", "def1", "jkl1", "prs1"] + with client.sql('SELECT * FROM LongMultipageQuery', page_size=1) as cursor: + for page in cursor: + assert len(page) == len(multipage_fields) + for field_number, value in enumerate(page[1:], start=1): + assert value == field_number * page[0] - client.sql('DROP TABLE LongMultipageQuery IF EXISTS') - client.sql("CREATE TABLE LongMultiPageQuery (%s, %s)" % - (fields[0] + " INT(11) PRIMARY KEY", ",".join(map(lambda f: f + " INT(11)", fields[1:])))) - - for id in range(1, 21): - client.sql( - "INSERT INTO LongMultipageQuery (%s) VALUES (%s)" % (",".join(fields), ",".join("?" * len(fields))), - query_args=[id] + list(i * id for i in range(1, len(fields)))) - - result = client.sql('SELECT * FROM LongMultipageQuery', page_size=1) - for page in result: - assert len(page) == len(fields) - for field_number, value in enumerate(page[1:], start=1): - assert value == field_number * page[0] - - client.sql(drop_query) +@pytest.mark.asyncio +async def test_long_multipage_query_async(async_client, async_long_multipage_table_fixture): + async with async_client.sql('SELECT * FROM LongMultipageQuery', page_size=1) as cursor: + async for page in cursor: + assert len(page) == len(multipage_fields) + for field_number, value in enumerate(page[1:], start=1): + assert value == field_number * page[0] def test_sql_not_create_cache_with_schema(client): @@ -203,20 +211,30 @@ def test_sql_not_create_cache_with_schema(client): client.sql(schema=None, cache='NOT_EXISTING', query_str='select * from NotExisting') +@pytest.mark.asyncio +async def test_sql_not_create_cache_with_schema_async(async_client): + with pytest.raises(SQLError, match=r".*Cache does not exist.*"): + await async_client.sql(schema=None, cache='NOT_EXISTING_ASYNC', query_str='select * from NotExistingAsync') + + def test_sql_not_create_cache_with_cache(client): with pytest.raises(SQLError, match=r".*Failed to set schema.*"): client.sql(schema='NOT_EXISTING', query_str='select * from NotExisting') -def test_query_with_cache(client): - test_key = 42 - test_value = 'Lorem ipsum' +@pytest.mark.asyncio +async def test_sql_not_create_cache_with_cache_async(async_client): + with pytest.raises(SQLError, match=r".*Failed to set schema.*"): + await async_client.sql(schema='NOT_EXISTING_ASYNC', query_str='select * from NotExistingAsync') - cache_name = test_query_with_cache.__name__.upper() + +@pytest.fixture +def indexed_cache_settings(): + cache_name = 'indexed_cache' schema_name = f'{cache_name}_schema'.upper() table_name = f'{cache_name}_table'.upper() - cache = client.create_cache({ + yield { PROP_NAME: cache_name, PROP_SQL_SCHEMA: schema_name, PROP_CACHE_MODE: CacheMode.PARTITIONED, @@ -243,18 +261,67 @@ def test_query_with_cache(client): ], }, ], - }) + } + + +@pytest.fixture +def indexed_cache_fixture(client, indexed_cache_settings): + cache_name = indexed_cache_settings[PROP_NAME] + schema_name = indexed_cache_settings[PROP_SQL_SCHEMA] + table_name = indexed_cache_settings[PROP_QUERY_ENTITIES][0]['table_name'] + + cache = client.create_cache(indexed_cache_settings) + + yield cache, cache_name, schema_name, table_name + cache.destroy() + + +@pytest.fixture +async def async_indexed_cache_fixture(async_client, indexed_cache_settings): + cache_name = indexed_cache_settings[PROP_NAME] + schema_name = indexed_cache_settings[PROP_SQL_SCHEMA] + table_name = indexed_cache_settings[PROP_QUERY_ENTITIES][0]['table_name'] + + cache = await async_client.create_cache(indexed_cache_settings) + + yield cache, cache_name, schema_name, table_name + await cache.destroy() + + +def test_query_with_cache(client, indexed_cache_fixture): + return __check_query_with_cache(client, indexed_cache_fixture) + + +@pytest.mark.asyncio +async def test_query_with_cache_async(async_client, async_indexed_cache_fixture): + return await __check_query_with_cache(async_client, async_indexed_cache_fixture) + - cache.put(test_key, test_value) +def __check_query_with_cache(client, cache_fixture): + test_key, test_value = 42, 'Lorem ipsum' + cache, cache_name, schema_name, table_name = cache_fixture + query = f'select value from {table_name}' args_to_check = [ ('schema', schema_name), ('cache', cache), - ('cache', cache.name), + ('cache', cache_name), ('cache', cache.cache_id) ] - for param, value in args_to_check: - page = client.sql(f'select value from {table_name}', **{param: value}) - received = next(page)[0] - assert test_value == received + def inner(): + cache.put(test_key, test_value) + for param, value in args_to_check: + with client.sql(query, **{param: value}) as cursor: + received = next(cursor)[0] + assert test_value == received + + async def async_inner(): + await cache.put(test_key, test_value) + for param, value in args_to_check: + async with client.sql(query, **{param: value}) as cursor: + row = await cursor.__anext__() + received = row[0] + assert test_value == received + + return async_inner() if isinstance(cache, AioCache) else inner() diff --git a/tests/common/test_sql_composite_key.py b/tests/common/test_sql_composite_key.py new file mode 100644 index 0000000..76de77e --- /dev/null +++ b/tests/common/test_sql_composite_key.py @@ -0,0 +1,168 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import OrderedDict +from enum import Enum + +import pytest + +from pyignite import GenericObjectMeta, AioClient +from pyignite.datatypes import IntObject, String + + +class StudentKey( + metaclass=GenericObjectMeta, + type_name='test.model.StudentKey', + schema=OrderedDict([ + ('ID', IntObject), + ('DEPT', String) + ]) +): + pass + + +class Student( + metaclass=GenericObjectMeta, + type_name='test.model.Student', + schema=OrderedDict([ + ('NAME', String), + ]) +): + pass + + +create_query = '''CREATE TABLE StudentTable ( + id INT(11), + dept VARCHAR, + name CHAR(24), + PRIMARY KEY (id, dept)) + WITH "CACHE_NAME=StudentCache, KEY_TYPE=test.model.StudentKey, VALUE_TYPE=test.model.Student"''' + +insert_query = '''INSERT INTO StudentTable (id, dept, name) VALUES (?, ?, ?)''' + +select_query = 'SELECT id, dept, name FROM StudentTable' + +select_kv_query = 'SELECT _key, _val FROM StudentTable' + +drop_query = 'DROP TABLE StudentTable IF EXISTS' + + +@pytest.fixture +def student_table_fixture(client): + yield from __create_student_table_fixture(client) + + +@pytest.fixture +async def async_student_table_fixture(async_client): + async for _ in __create_student_table_fixture(async_client): + yield + + +def __create_student_table_fixture(client): + def inner(): + client.sql(drop_query) + client.sql(create_query) + yield None + client.sql(drop_query) + + async def inner_async(): + await client.sql(drop_query) + await client.sql(create_query) + yield None + await client.sql(drop_query) + + return inner_async() if isinstance(client, AioClient) else inner() + + +class InsertMode(Enum): + SQL = 1 + CACHE = 2 + + +@pytest.mark.parametrize('insert_mode', [InsertMode.SQL, InsertMode.CACHE]) +def test_sql_composite_key(client, insert_mode, student_table_fixture): + __perform_test(client, insert_mode) + + +@pytest.mark.asyncio +@pytest.mark.parametrize('insert_mode', [InsertMode.SQL, InsertMode.CACHE]) +async def test_sql_composite_key_async(async_client, insert_mode, async_student_table_fixture): + await __perform_test(async_client, insert_mode) + + +def __perform_test(client, insert=InsertMode.SQL): + student_key = StudentKey(2, 'Business') + student_val = Student('Abe') + + def validate_query_result(key, val, query_result): + """ + Compare query result with expected key and value. + """ + assert len(query_result) == 2 + sql_row = dict(zip(query_result[0], query_result[1])) + + assert sql_row['ID'] == key.ID + assert sql_row['DEPT'] == key.DEPT + assert sql_row['NAME'] == val.NAME + + def validate_kv_query_result(key, val, query_result): + """ + Compare query result with expected key and value. + """ + assert len(query_result) == 2 + sql_row = dict(zip(query_result[0], query_result[1])) + + sql_key, sql_val = sql_row['_KEY'], sql_row['_VAL'] + assert sql_key.ID == key.ID + assert sql_key.DEPT == key.DEPT + assert sql_val.NAME == val.NAME + + def inner(): + if insert == InsertMode.SQL: + result = client.sql(insert_query, query_args=[student_key.ID, student_key.DEPT, student_val.NAME]) + assert next(result)[0] == 1 + else: + studentCache = client.get_cache('StudentCache') + studentCache.put(student_key, student_val) + val = studentCache.get(student_key) + assert val is not None + assert val.NAME == student_val.NAME + + query_result = list(client.sql(select_query, include_field_names=True)) + validate_query_result(student_key, student_val, query_result) + + query_result = list(client.sql(select_kv_query, include_field_names=True)) + validate_kv_query_result(student_key, student_val, query_result) + + async def inner_async(): + if insert == InsertMode.SQL: + result = await client.sql(insert_query, query_args=[student_key.ID, student_key.DEPT, student_val.NAME]) + assert (await result.__anext__())[0] == 1 + else: + studentCache = await client.get_cache('StudentCache') + await studentCache.put(student_key, student_val) + val = await studentCache.get(student_key) + assert val is not None + assert val.NAME == student_val.NAME + + async with client.sql(select_query, include_field_names=True) as cursor: + query_result = [r async for r in cursor] + validate_query_result(student_key, student_val, query_result) + + async with client.sql(select_kv_query, include_field_names=True) as cursor: + query_result = [r async for r in cursor] + validate_kv_query_result(student_key, student_val, query_result) + + return inner_async() if isinstance(client, AioClient) else inner() diff --git a/tests/conftest.py b/tests/conftest.py index 59b7d3a..65134fd 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -12,6 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import asyncio + import pytest @@ -27,7 +29,7 @@ def run_examples(request): def skip_if_no_cext(request): skip = False try: - from pyignite import _cutils + from pyignite import _cutils # noqa: F401 except ImportError: if request.config.getoption('--force-cext'): pytest.fail("C extension failed to build, fail test because of --force-cext is set.") @@ -38,6 +40,14 @@ def skip_if_no_cext(request): pytest.skip('skipped c extensions test, c extension is not available.') +@pytest.fixture(scope='session') +def event_loop(): + """Create an instance of the default event loop for each test case.""" + loop = asyncio.get_event_loop_policy().new_event_loop() + yield loop + loop.close() + + def pytest_addoption(parser): parser.addoption( '--examples', diff --git a/tests/security/test_auth.py b/tests/security/test_auth.py index 2dd19a0..4a1c52d 100644 --- a/tests/security/test_auth.py +++ b/tests/security/test_auth.py @@ -15,7 +15,7 @@ import pytest from pyignite.exceptions import AuthenticationError -from tests.util import start_ignite_gen, clear_ignite_work_dir, get_client +from tests.util import start_ignite_gen, clear_ignite_work_dir, get_client, get_client_async DEFAULT_IGNITE_USERNAME = 'ignite' DEFAULT_IGNITE_PASSWORD = 'ignite' @@ -47,13 +47,27 @@ def test_auth_success(with_ssl, ssl_params): assert all(node.alive for node in client._nodes) +@pytest.mark.asyncio +async def test_auth_success_async(with_ssl, ssl_params): + ssl_params['use_ssl'] = with_ssl + + async with get_client_async(username=DEFAULT_IGNITE_USERNAME, password=DEFAULT_IGNITE_PASSWORD, + **ssl_params) as client: + await client.connect("127.0.0.1", 10801) + + assert all(node.alive for node in client._nodes) + + +auth_failed_params = [ + [DEFAULT_IGNITE_USERNAME, None], + ['invalid_user', 'invalid_password'], + [None, None] +] + + @pytest.mark.parametrize( 'username, password', - [ - [DEFAULT_IGNITE_USERNAME, None], - ['invalid_user', 'invalid_password'], - [None, None] - ] + auth_failed_params ) def test_auth_failed(username, password, with_ssl, ssl_params): ssl_params['use_ssl'] = with_ssl @@ -61,3 +75,16 @@ def test_auth_failed(username, password, with_ssl, ssl_params): with pytest.raises(AuthenticationError): with get_client(username=username, password=password, **ssl_params) as client: client.connect("127.0.0.1", 10801) + + +@pytest.mark.parametrize( + 'username, password', + auth_failed_params +) +@pytest.mark.asyncio +async def test_auth_failed_async(username, password, with_ssl, ssl_params): + ssl_params['use_ssl'] = with_ssl + + with pytest.raises(AuthenticationError): + async with get_client_async(username=username, password=password, **ssl_params) as client: + await client.connect("127.0.0.1", 10801) diff --git a/tests/security/test_ssl.py b/tests/security/test_ssl.py index 6463a03..32db98f 100644 --- a/tests/security/test_ssl.py +++ b/tests/security/test_ssl.py @@ -15,7 +15,7 @@ import pytest from pyignite.exceptions import ReconnectError -from tests.util import start_ignite_gen, get_client, get_or_create_cache +from tests.util import start_ignite_gen, get_client, get_or_create_cache, get_client_async, get_or_create_cache_async @pytest.fixture(scope='module', autouse=True) @@ -30,27 +30,58 @@ def test_connect_ssl_keystore_with_password(ssl_params_with_password): def test_connect_ssl(ssl_params): __test_connect_ssl(**ssl_params) -def __test_connect_ssl(**kwargs): + +@pytest.mark.asyncio +async def test_connect_ssl_keystore_with_password_async(ssl_params_with_password): + await __test_connect_ssl(is_async=True, **ssl_params_with_password) + + +@pytest.mark.asyncio +async def test_connect_ssl_async(ssl_params): + await __test_connect_ssl(is_async=True, **ssl_params) + + +def __test_connect_ssl(is_async=False, **kwargs): kwargs['use_ssl'] = True - with get_client(**kwargs) as client: - client.connect("127.0.0.1", 10801) + def inner(): + with get_client(**kwargs) as client: + client.connect("127.0.0.1", 10801) + + with get_or_create_cache(client, 'test-cache') as cache: + cache.put(1, 1) + + assert cache.get(1) == 1 - with get_or_create_cache(client, 'test-cache') as cache: - cache.put(1, 1) + async def inner_async(): + async with get_client_async(**kwargs) as client: + await client.connect("127.0.0.1", 10801) - assert cache.get(1) == 1 + async with get_or_create_cache_async(client, 'test-cache') as cache: + await cache.put(1, 1) + assert (await cache.get(1)) == 1 -@pytest.mark.parametrize( - 'invalid_ssl_params', - [ - {'use_ssl': False}, - {'use_ssl': True}, - {'use_ssl': True, 'ssl_keyfile': 'invalid.pem', 'ssl_certfile': 'invalid.pem'} - ] -) + return inner_async() if is_async else inner() + + +invalid_params = [ + {'use_ssl': False}, + {'use_ssl': True}, + {'use_ssl': True, 'ssl_keyfile': 'invalid.pem', 'ssl_certfile': 'invalid.pem'} +] + + +@pytest.mark.parametrize('invalid_ssl_params', invalid_params) def test_connection_error_with_incorrect_config(invalid_ssl_params): with pytest.raises(ReconnectError): with get_client(**invalid_ssl_params) as client: client.connect([("127.0.0.1", 10801)]) + + +@pytest.mark.parametrize('invalid_ssl_params', invalid_params) +@pytest.mark.asyncio +async def test_connection_error_with_incorrect_config_async(invalid_ssl_params): + with pytest.raises(ReconnectError): + async with get_client_async(**invalid_ssl_params) as client: + await client.connect([("127.0.0.1", 10801)]) diff --git a/tests/test_cutils.py b/tests/test_cutils.py index e7c095e..d66425f 100644 --- a/tests/test_cutils.py +++ b/tests/test_cutils.py @@ -27,8 +27,8 @@ _cutils_hashcode = _cutils.hashcode _cutils_schema_id = _cutils.schema_id except ImportError: - _cutils_hashcode = lambda x: None - _cutils_schema_id = lambda x: None + _cutils_hashcode = lambda x: None # noqa: E731 + _cutils_schema_id = lambda x: None # noqa: E731 pass diff --git a/tests/util.py b/tests/util.py index af4c324..f1243fc 100644 --- a/tests/util.py +++ b/tests/util.py @@ -12,8 +12,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import asyncio import contextlib import glob +import inspect import os import shutil @@ -24,7 +26,12 @@ import subprocess import time -from pyignite import Client +from pyignite import Client, AioClient + +try: + from contextlib import asynccontextmanager +except ImportError: + from async_generator import asynccontextmanager @contextlib.contextmanager @@ -36,6 +43,15 @@ def get_client(**kwargs): client.close() +@asynccontextmanager +async def get_client_async(**kwargs): + client = AioClient(**kwargs) + try: + yield client + finally: + await client.close() + + @contextlib.contextmanager def get_or_create_cache(client, cache_name): cache = client.get_or_create_cache(cache_name) @@ -45,6 +61,15 @@ def get_or_create_cache(client, cache_name): cache.destroy() +@asynccontextmanager +async def get_or_create_cache_async(client, cache_name): + cache = await client.get_or_create_cache(cache_name) + try: + yield cache + finally: + await cache.destroy() + + def wait_for_condition(condition, interval=0.1, timeout=10, error=None): start = time.time() res = condition() @@ -62,6 +87,23 @@ def wait_for_condition(condition, interval=0.1, timeout=10, error=None): return False +async def wait_for_condition_async(condition, interval=0.1, timeout=10, error=None): + start = time.time() + res = await condition() if inspect.iscoroutinefunction(condition) else condition() + + while not res and time.time() - start < timeout: + await asyncio.sleep(interval) + res = await condition() if inspect.iscoroutinefunction(condition) else condition() + + if res: + return True + + if error is not None: + raise Exception(error) + + return False + + def is_windows(): return os.name == "nt" diff --git a/tox.ini b/tox.ini index 3ab8dea..90153da 100644 --- a/tox.ini +++ b/tox.ini @@ -15,7 +15,15 @@ [tox] skipsdist = True -envlist = py{36,37,38,39} +envlist = codestyle,py{36,37,38,39} + +[flake8] +max-line-length=120 +ignore = F401,F403,F405,F821 + +[testenv:codestyle] +basepython = python3.8 +commands = flake8 [testenv] passenv = TEAMCITY_VERSION IGNITE_HOME