diff --git a/README.md b/README.md index 0b76791..8748cab 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # e6data Python Connector -![version](https://img.shields.io/badge/version-1.1.4-blue.svg) +![version](https://img.shields.io/badge/version-1.1.5-blue.svg) ## Introduction diff --git a/e6data_python_connector/dialect.py b/e6data_python_connector/dialect.py index aba21b7..8e96a5f 100644 --- a/e6data_python_connector/dialect.py +++ b/e6data_python_connector/dialect.py @@ -211,7 +211,6 @@ class E6dataDialect(default.DefaultDialect): type_compiler = E6dataTypeCompiler supports_sane_rowcount = False driver = b'thrift' - name = b'E6data' scheme = 'e6data' catalog_name = None @@ -223,19 +222,20 @@ def dbapi(cls): return e6data_grpc def create_connect_args(self, url): - db = None + database = None if url.query.get("schema"): - db = url.query.get("schema") + database = url.query.get("schema") self.catalog_name = url.query.get("catalog") if not self.catalog_name: raise Exception('Please specify catalog in query parameter.') + kwargs = { "host": url.host, "port": url.port, "scheme": self.scheme, "username": url.username or None, "password": url.password or None, - "database": db, + "database": database, "catalog": self.catalog_name } return [], kwargs @@ -243,7 +243,6 @@ def create_connect_args(self, url): def get_schema_names(self, connection, **kw): # Equivalent to SHOW DATABASES # Rerouting to view names - engine = connection if isinstance(connection, Engine): cursor = connection.raw_connection().connection.cursor(catalog_name=self.catalog_name) elif isinstance(connection, Connection): @@ -252,12 +251,12 @@ def get_schema_names(self, connection, **kw): raise Exception("Got type of object {typ}".format(typ=type(connection))) client = cursor.connection - return client.get_schema_names() + return client.get_schema_names(catalog=self.catalog_name) def get_view_names(self, connection, schema=None, **kw): return [] - def _get_table_columns(self, connection, table): + def _get_table_columns(self, connection, schema, table): try: if isinstance(connection, Engine): cursor = connection.raw_connection().connection.cursor(catalog_name=self.catalog_name) @@ -267,54 +266,43 @@ def _get_table_columns(self, connection, table): raise Exception("Got type of object {typ}".format(typ=type(connection))) client = cursor.connection - columns = client.getColumns("default", table) + columns = client.get_columns(self, self.catalog_name, schema, table) rows = list() for column in columns: row = dict() - row["col_name"] = column.fieldName - row["data_type"] = column.fieldType + row["col_name"] = column.get('fieldName') + row["data_type"] = column.get('fieldType') rows.append(row) - return rows except exc.OperationalError as e: # Does the table exist? raise e def has_table(self, connection, table_name, schema=None, **kwargs): - try: - self._get_table_columns(connection, table_name) - return True - except Exception: - return False - - def get_columns(self, connection, table_name, schema=None, **kw): - rows = self._get_table_columns(connection, table_name) - # # Strip whitespace - # rows = [[col.strip() if col else None for col in row] for row in rows] - # Filter out empty rows and comment - # rows = [row for row in rows if row[0] and row[0] != '# col_name'] - result = [] - for row in rows: - col_name = row['col_name'] - col_type = row['data_type'] - # Take out the more detailed type information - # e.g. 'map' -> 'map' - # 'decimal(10,1)' -> decimal - col_type = re.search(r'^\w+', col_type).group(0) - try: - coltype = _type_map[col_type.lower()] - _logger.info("Got column {column} with data type {dt}".format(column=col_name, dt=coltype)) - except KeyError: - util.warn("Did not recognize type '%s' of column '%s'" % (col_type, col_name)) - coltype = types.NullType - - result.append({ - 'name': col_name, - 'type': coltype, - 'nullable': True, - 'default': None, - }) - return result + return True + # try: + # self._get_table_columns(connection, schema, table_name) + # return True + # except Exception as e: + # return False + + def get_columns(self, connection, table_name, schema, **kwargs): + if isinstance(connection, Engine): + cursor = connection.raw_connection().connection.cursor(catalog_name=self.catalog_name) + elif isinstance(connection, Connection): + cursor = connection.connection.cursor(catalog_name=self.catalog_name) + else: + raise Exception("Got type of object {typ}".format(typ=type(connection))) + + client = cursor.connection + columns = client.get_columns(self.catalog_name, schema, table_name) + rows = list() + for column in columns: + row = dict() + row["name"] = column.get('fieldName') + row["type"] = lambda: column.get('fieldType') + rows.append(row) + return rows def get_foreign_keys(self, connection, table_name, schema=None, **kw): # Hive has no support for foreign keys. @@ -333,12 +321,12 @@ def get_table_names(self, connection, schema=None, **kw): if isinstance(connection, Engine): cursor = connection.raw_connection().connection.cursor(catalog_name=self.catalog_name) elif isinstance(connection, Connection): - cursor = connection.connection.cursor() + cursor = connection.connection.cursor(catalog_name=self.catalog_name) else: raise Exception("Got type of object {typ}".format(typ=type(connection))) client = cursor.connection - return client.getTables(schema) + return client.get_tables(self.catalog_name, schema) def do_rollback(self, dbapi_connection): # No transactions for Hive diff --git a/e6data_python_connector/e6data_grpc.py b/e6data_python_connector/e6data_grpc.py index 428b686..48f270e 100644 --- a/e6data_python_connector/e6data_grpc.py +++ b/e6data_python_connector/e6data_grpc.py @@ -225,24 +225,32 @@ def dry_run(self, query): dry_run_response = self._client.dryRun(dry_run_request) return dry_run_response.dryrunValue - def get_tables(self, database): - get_table_request = e6x_engine_pb2.GetTablesRequest(sessionId=self.get_session_id, schema=database) - get_table_response = self._client.getTables(get_table_request) - return get_table_response.tables + def get_tables(self, catalog, database): + get_table_request = e6x_engine_pb2.GetTablesV2Request( + sessionId=self.get_session_id, + schema=database, + catalog=catalog + ) + get_table_response = self._client.getTablesV2(get_table_request) + return list(get_table_response.tables) - def get_columns(self, database, table): - get_columns_request = e6x_engine_pb2.GetColumnsRequest( + def get_columns(self, catalog, database, table): + get_columns_request = e6x_engine_pb2.GetColumnsV2Request( sessionId=self.get_session_id, schema=database, - table=table + table=table, + catalog=catalog ) - get_columns_response = self._client.getColumns(get_columns_request) - return get_columns_response.fieldInfo + get_columns_response = self._client.getColumnsV2(get_columns_request) + return [{'fieldName': row.fieldName, 'fieldType': row.fieldType} for row in get_columns_response.fieldInfo] - def get_schema_names(self): - get_schema_request = e6x_engine_pb2.GetSchemaNamesRequest(sessionId=self.get_session_id) - get_schema_response = self._client.getSchemaNames(get_schema_request) - return get_schema_response.schemas + def get_schema_names(self, catalog): + get_schema_request = e6x_engine_pb2.GetSchemaNamesV2Request( + sessionId=self.get_session_id, + catalog=catalog + ) + get_schema_response = self._client.getSchemaNamesV2(get_schema_request) + return list(get_schema_response.schemas) def commit(self): """We do not support transactions, so this does nothing.""" diff --git a/setup.py b/setup.py index b1621a2..5412722 100644 --- a/setup.py +++ b/setup.py @@ -10,13 +10,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os - import setuptools -envstring = lambda var: os.environ.get(var) or "" - -VERSION = [1, 1, 4] +VERSION = [1, 1, 5] def get_long_desc():