Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# e6data Python Connector

![version](https://img.shields.io/badge/version-1.1.4-blue.svg)
![version](https://img.shields.io/badge/version-1.1.5-blue.svg)

## Introduction

Expand Down
82 changes: 35 additions & 47 deletions e6data_python_connector/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,6 @@ class E6dataDialect(default.DefaultDialect):
type_compiler = E6dataTypeCompiler
supports_sane_rowcount = False
driver = b'thrift'
name = b'E6data'
scheme = 'e6data'
catalog_name = None

Expand All @@ -223,27 +222,27 @@ def dbapi(cls):
return e6data_grpc

def create_connect_args(self, url):
db = None
database = None
if url.query.get("schema"):
db = url.query.get("schema")
database = url.query.get("schema")
self.catalog_name = url.query.get("catalog")
if not self.catalog_name:
raise Exception('Please specify catalog in query parameter.')

kwargs = {
"host": url.host,
"port": url.port,
"scheme": self.scheme,
"username": url.username or None,
"password": url.password or None,
"database": db,
"database": database,
"catalog": self.catalog_name
}
return [], kwargs

def get_schema_names(self, connection, **kw):
# Equivalent to SHOW DATABASES
# Rerouting to view names
engine = connection
if isinstance(connection, Engine):
cursor = connection.raw_connection().connection.cursor(catalog_name=self.catalog_name)
elif isinstance(connection, Connection):
Expand All @@ -252,12 +251,12 @@ def get_schema_names(self, connection, **kw):
raise Exception("Got type of object {typ}".format(typ=type(connection)))

client = cursor.connection
return client.get_schema_names()
return client.get_schema_names(catalog=self.catalog_name)

def get_view_names(self, connection, schema=None, **kw):
return []

def _get_table_columns(self, connection, table):
def _get_table_columns(self, connection, schema, table):
try:
if isinstance(connection, Engine):
cursor = connection.raw_connection().connection.cursor(catalog_name=self.catalog_name)
Expand All @@ -267,54 +266,43 @@ def _get_table_columns(self, connection, table):
raise Exception("Got type of object {typ}".format(typ=type(connection)))

client = cursor.connection
columns = client.getColumns("default", table)
columns = client.get_columns(self, self.catalog_name, schema, table)
rows = list()
for column in columns:
row = dict()
row["col_name"] = column.fieldName
row["data_type"] = column.fieldType
row["col_name"] = column.get('fieldName')
row["data_type"] = column.get('fieldType')
rows.append(row)

return rows
except exc.OperationalError as e:
# Does the table exist?
raise e

def has_table(self, connection, table_name, schema=None, **kwargs):
try:
self._get_table_columns(connection, table_name)
return True
except Exception:
return False

def get_columns(self, connection, table_name, schema=None, **kw):
rows = self._get_table_columns(connection, table_name)
# # Strip whitespace
# rows = [[col.strip() if col else None for col in row] for row in rows]
# Filter out empty rows and comment
# rows = [row for row in rows if row[0] and row[0] != '# col_name']
result = []
for row in rows:
col_name = row['col_name']
col_type = row['data_type']
# Take out the more detailed type information
# e.g. 'map<int,int>' -> 'map'
# 'decimal(10,1)' -> decimal
col_type = re.search(r'^\w+', col_type).group(0)
try:
coltype = _type_map[col_type.lower()]
_logger.info("Got column {column} with data type {dt}".format(column=col_name, dt=coltype))
except KeyError:
util.warn("Did not recognize type '%s' of column '%s'" % (col_type, col_name))
coltype = types.NullType

result.append({
'name': col_name,
'type': coltype,
'nullable': True,
'default': None,
})
return result
return True
# try:
# self._get_table_columns(connection, schema, table_name)
# return True
# except Exception as e:
# return False

def get_columns(self, connection, table_name, schema, **kwargs):
if isinstance(connection, Engine):
cursor = connection.raw_connection().connection.cursor(catalog_name=self.catalog_name)
elif isinstance(connection, Connection):
cursor = connection.connection.cursor(catalog_name=self.catalog_name)
else:
raise Exception("Got type of object {typ}".format(typ=type(connection)))

client = cursor.connection
columns = client.get_columns(self.catalog_name, schema, table_name)
rows = list()
for column in columns:
row = dict()
row["name"] = column.get('fieldName')
row["type"] = lambda: column.get('fieldType')
rows.append(row)
return rows

def get_foreign_keys(self, connection, table_name, schema=None, **kw):
# Hive has no support for foreign keys.
Expand All @@ -333,12 +321,12 @@ def get_table_names(self, connection, schema=None, **kw):
if isinstance(connection, Engine):
cursor = connection.raw_connection().connection.cursor(catalog_name=self.catalog_name)
elif isinstance(connection, Connection):
cursor = connection.connection.cursor()
cursor = connection.connection.cursor(catalog_name=self.catalog_name)
else:
raise Exception("Got type of object {typ}".format(typ=type(connection)))

client = cursor.connection
return client.getTables(schema)
return client.get_tables(self.catalog_name, schema)

def do_rollback(self, dbapi_connection):
# No transactions for Hive
Expand Down
34 changes: 21 additions & 13 deletions e6data_python_connector/e6data_grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,24 +225,32 @@ def dry_run(self, query):
dry_run_response = self._client.dryRun(dry_run_request)
return dry_run_response.dryrunValue

def get_tables(self, database):
get_table_request = e6x_engine_pb2.GetTablesRequest(sessionId=self.get_session_id, schema=database)
get_table_response = self._client.getTables(get_table_request)
return get_table_response.tables
def get_tables(self, catalog, database):
get_table_request = e6x_engine_pb2.GetTablesV2Request(
sessionId=self.get_session_id,
schema=database,
catalog=catalog
)
get_table_response = self._client.getTablesV2(get_table_request)
return list(get_table_response.tables)

def get_columns(self, database, table):
get_columns_request = e6x_engine_pb2.GetColumnsRequest(
def get_columns(self, catalog, database, table):
get_columns_request = e6x_engine_pb2.GetColumnsV2Request(
sessionId=self.get_session_id,
schema=database,
table=table
table=table,
catalog=catalog
)
get_columns_response = self._client.getColumns(get_columns_request)
return get_columns_response.fieldInfo
get_columns_response = self._client.getColumnsV2(get_columns_request)
return [{'fieldName': row.fieldName, 'fieldType': row.fieldType} for row in get_columns_response.fieldInfo]

def get_schema_names(self):
get_schema_request = e6x_engine_pb2.GetSchemaNamesRequest(sessionId=self.get_session_id)
get_schema_response = self._client.getSchemaNames(get_schema_request)
return get_schema_response.schemas
def get_schema_names(self, catalog):
get_schema_request = e6x_engine_pb2.GetSchemaNamesV2Request(
sessionId=self.get_session_id,
catalog=catalog
)
get_schema_response = self._client.getSchemaNamesV2(get_schema_request)
return list(get_schema_response.schemas)

def commit(self):
"""We do not support transactions, so this does nothing."""
Expand Down
6 changes: 1 addition & 5 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import setuptools

envstring = lambda var: os.environ.get(var) or ""

VERSION = [1, 1, 4]
VERSION = [1, 1, 5]


def get_long_desc():
Expand Down