def zlib_compress(data):
    """
    Compress things in a py2/3 safe fashion.

    Text (``str`` on Python 3, ``unicode`` on Python 2) is encoded to UTF-8
    before compression because ``zlib.compress`` only works on bytes-like
    input. Any other value is handed to ``zlib.compress`` unchanged.

    :param data: text or bytes-like payload to compress
    :returns: the zlib-compressed payload as a byte string
    :raises TypeError: if ``data`` is neither text nor bytes-like

    >>> json_str = '{"test": 1}'
    >>> blob = zlib_compress(json_str)
    """
    # type(u"") is ``str`` on Python 3 and ``unicode`` on Python 2, so this
    # single check selects "text" on both interpreters without a version flag.
    if isinstance(data, type(u"")):
        data = data.encode("utf-8")
    return zlib.compress(data)


def zlib_uncompress_to_string(blob):
    """
    Uncompress things to a string in a py2/3 safe fashion.

    Accepts any bytes-like blob (``bytes``, ``bytearray``, ``memoryview``);
    a text blob is encoded to UTF-8 first, as before.

    :param blob: zlib-compressed payload
    :returns: the decompressed payload as a native ``str``
    :raises zlib.error: if ``blob`` is not valid zlib data
    :raises TypeError: if ``blob`` is neither text nor bytes-like

    >>> json_str = '{"test": 1}'
    >>> blob = zlib_compress(json_str)
    >>> got_str = zlib_uncompress_to_string(blob)
    >>> got_str == json_str
    True
    """
    # Only text needs encoding. Everything else goes straight to zlib, which
    # already understands the buffer protocol; forcing it through
    # ``bytes(blob, "utf-8")`` would raise TypeError for e.g. a bytearray.
    if isinstance(blob, type(u"")):
        blob = blob.encode("utf-8")

    decompressed = zlib.decompress(blob)

    # On Python 2 the decompressed byte string *is* the native ``str`` and is
    # returned as-is; on Python 3 it is ``bytes`` and has to be decoded.
    if isinstance(decompressed, str):
        return decompressed
    return decompressed.decode("utf-8")
datasource_perms = self.get_view_menus('datasource_access') slice_ids_qry = ( db.session - .query(Slice.id) - .filter(Slice.perm.in_(datasource_perms)) + .query(Slice.id) + .filter(Slice.perm.in_(datasource_perms)) ) query = query.filter( Dash.id.in_( db.session.query(Dash.id) - .distinct() - .join(Dash.slices) - .filter(Slice.id.in_(slice_ids_qry)) + .distinct() + .join(Dash.slices) + .filter(Slice.id.in_(slice_ids_qry)) ) ) return query @@ -762,8 +761,8 @@ def override_role_permissions(self): granted_perms = [] for datasource in datasources: view_menu_perm = sm.find_permission_view_menu( - view_menu_name=datasource.perm, - permission_name='datasource_access') + view_menu_name=datasource.perm, + permission_name='datasource_access') # prevent creating empty permissions if view_menu_perm and view_menu_perm.view_menu: role.permissions.append(view_menu_perm) @@ -783,8 +782,8 @@ def request_access(self): if dashboard_id: dash = ( db.session.query(models.Dashboard) - .filter_by(id=int(dashboard_id)) - .one() + .filter_by(id=int(dashboard_id)) + .one() ) datasources |= dash.datasources datasource_id = request.args.get('datasource_id') @@ -793,8 +792,8 @@ def request_access(self): ds_class = ConnectorRegistry.sources.get(datasource_type) datasource = ( db.session.query(ds_class) - .filter_by(id=int(datasource_id)) - .one() + .filter_by(id=int(datasource_id)) + .one() ) datasources.add(datasource) if request.args.get('action') == 'go': @@ -823,7 +822,7 @@ def clean_fulfilled_requests(session): r.datasource_type, r.datasource_id, session) user = sm.get_user_by_id(r.created_by_fk) if not datasource or \ - self.datasource_access(datasource, user): + self.datasource_access(datasource, user): # datasource does not exist anymore session.delete(r) session.commit() @@ -848,11 +847,11 @@ def clean_fulfilled_requests(session): requests = ( session.query(DAR) - .filter( + .filter( DAR.datasource_id == datasource_id, DAR.datasource_type == datasource_type, DAR.created_by_fk == 
requested_by.id) - .all() + .all() ) if not requests: @@ -930,8 +929,8 @@ def get_viz( if slice_id: slc = ( db.session.query(models.Slice) - .filter_by(id=slice_id) - .one() + .filter_by(id=slice_id) + .one() ) return slc.get_viz() else: @@ -951,7 +950,7 @@ def slice(self, slice_id): viz_obj = self.get_viz(slice_id) endpoint = ( '/superset/explore/{}/{}?form_data={}' - .format( + .format( viz_obj.datasource.type, viz_obj.datasource.id, json.dumps(viz_obj.form_data) @@ -1059,8 +1058,8 @@ def explore(self, datasource_type, datasource_id): error_redirect = '/slicemodelview/list/' datasource = ( db.session.query(ConnectorRegistry.sources[datasource_type]) - .filter_by(id=datasource_id) - .one() + .filter_by(id=datasource_id) + .one() ) if not datasource: @@ -1192,8 +1191,8 @@ def save_or_overwrite_slice( if request.args.get('add_to_dash') == 'existing': dash = ( db.session.query(models.Dashboard) - .filter_by(id=int(request.args.get('save_to_dashboard_id'))) - .one() + .filter_by(id=int(request.args.get('save_to_dashboard_id'))) + .one() ) flash( "Slice [{}] was added to dashboard [{}]".format( @@ -1257,14 +1256,14 @@ def activity_per_day(self): Log = models.Log # noqa qry = ( db.session - .query( + .query( Log.dt, sqla.func.count()) - .group_by(Log.dt) - .all() + .group_by(Log.dt) + .all() ) payload = {str(time.mktime(dt.timetuple())): - ccount for dt, ccount in qry if dt} + ccount for dt, ccount in qry if dt} return json_success(json.dumps(payload)) @api @@ -1273,9 +1272,9 @@ def activity_per_day(self): def schemas(self, db_id): database = ( db.session - .query(models.Database) - .filter_by(id=db_id) - .one() + .query(models.Database) + .filter_by(id=db_id) + .one() ) return Response( json.dumps({'schemas': database.all_schema_names()}), @@ -1404,9 +1403,9 @@ def testconn(self): if db_name: database = ( db.session - .query(models.Database) - .filter_by(database_name=db_name) - .first() + .query(models.Database) + .filter_by(database_name=db_name) + .first() ) if 
database and uri == database.safe_sqlalchemy_uri(): # the password-masked uri was passed @@ -1414,16 +1413,16 @@ def testconn(self): uri = database.sqlalchemy_uri_decrypted connect_args = ( request.json - .get('extras', {}) - .get('engine_params', {}) - .get('connect_args', {})) + .get('extras', {}) + .get('engine_params', {}) + .get('connect_args', {})) engine = create_engine(uri, connect_args=connect_args) engine.connect() return json.dumps(engine.table_names(), indent=4) except Exception as e: return json_error_response(( - "Connection failed!\n\n" - "The error message returned was:\n{}").format(e)) + "Connection failed!\n\n" + "The error message returned was:\n{}").format(e)) @api @has_access_api @@ -1433,22 +1432,22 @@ def recent_activity(self, user_id): M = models # noqa qry = ( db.session.query(M.Log, M.Dashboard, M.Slice) - .outerjoin( + .outerjoin( M.Dashboard, M.Dashboard.id == M.Log.dashboard_id ) - .outerjoin( + .outerjoin( M.Slice, M.Slice.id == M.Log.slice_id ) - .filter( + .filter( sqla.and_( ~M.Log.action.in_(('queries', 'shortner', 'sql_json')), M.Log.user_id == user_id, - ) + ) ) - .order_by(M.Log.dttm.desc()) - .limit(1000) + .order_by(M.Log.dttm.desc()) + .limit(1000) ) payload = [] for log in qry.all(): @@ -1479,15 +1478,15 @@ def fave_dashboards(self, user_id): models.Dashboard, models.FavStar.dttm, ) - .join( + .join( models.FavStar, sqla.and_( models.FavStar.user_id == int(user_id), models.FavStar.class_name == 'Dashboard', models.Dashboard.id == models.FavStar.obj_id, - ) + ) ) - .order_by( + .order_by( models.FavStar.dttm.desc() ) ) @@ -1518,23 +1517,23 @@ def created_dashboards(self, user_id): db.session.query( Dash, ) - .filter( + .filter( sqla.or_( Dash.created_by_fk == user_id, Dash.changed_by_fk == user_id, - ) + ) ) - .order_by( + .order_by( Dash.changed_on.desc() ) ) payload = [{ - 'id': o.id, - 'dashboard': o.dashboard_link(), - 'title': o.dashboard_title, - 'url': o.url, - 'dttm': o.changed_on, - } for o in qry.all()] + 'id': 
o.id, + 'dashboard': o.dashboard_link(), + 'title': o.dashboard_title, + 'url': o.url, + 'dttm': o.changed_on, + } for o in qry.all()] return json_success( json.dumps(payload, default=utils.json_int_dttm_ser)) @@ -1546,20 +1545,20 @@ def created_slices(self, user_id): Slice = models.Slice # noqa qry = ( db.session.query(Slice) - .filter( + .filter( sqla.or_( Slice.created_by_fk == user_id, Slice.changed_by_fk == user_id, - ) + ) ) - .order_by(Slice.changed_on.desc()) + .order_by(Slice.changed_on.desc()) ) payload = [{ - 'id': o.id, - 'title': o.slice_name, - 'url': o.slice_url, - 'dttm': o.changed_on, - } for o in qry.all()] + 'id': o.id, + 'title': o.slice_name, + 'url': o.slice_url, + 'dttm': o.changed_on, + } for o in qry.all()] return json_success( json.dumps(payload, default=utils.json_int_dttm_ser)) @@ -1573,15 +1572,15 @@ def fave_slices(self, user_id): models.Slice, models.FavStar.dttm, ) - .join( + .join( models.FavStar, sqla.and_( models.FavStar.user_id == int(user_id), models.FavStar.class_name == 'slice', models.Slice.id == models.FavStar.obj_id, - ) + ) ) - .order_by( + .order_by( models.FavStar.dttm.desc() ) ) @@ -1626,8 +1625,8 @@ def warm_up_cache(self): SqlaTable = ConnectorRegistry.sources['table'] table = ( session.query(SqlaTable) - .join(models.Database) - .filter( + .join(models.Database) + .filter( models.Database.database_name == db_name or SqlaTable.table_name == table_name) ).first() @@ -1784,8 +1783,8 @@ def sqllab_viz(self): SqlaTable = ConnectorRegistry.sources['table'] table = ( db.session.query(SqlaTable) - .filter_by(table_name=table_name) - .first() + .filter_by(table_name=table_name) + .first() ) if not table: table = SqlaTable(table_name=table_name) @@ -1813,8 +1812,8 @@ def sqllab_viz(self): metrics.append(models.SqlMetric( metric_name="{agg}__{column_name}".format(**locals()), expression="COUNT(DISTINCT {column_name})" - .format(**locals()), - )) + .format(**locals()), + )) else: metrics.append(models.SqlMetric( 
metric_name="{agg}__{column_name}".format(**locals()), @@ -1881,7 +1880,7 @@ def table(self, database_id, table_name, schema): 'keys': [ k for k in keys if col['name'] in k.get('column_names') - ], + ], }) tbl = { 'name': table_name, @@ -1952,7 +1951,7 @@ def results(self, key): return json_error_response(get_datasource_access_error_msg( '{}'.format(rejected_tables))) - payload = zlib.decompress(blob) + payload = utils.zlib_uncompress_to_string(blob) display_limit = app.config.get('DISPLAY_SQL_MAX_ROW', None) if display_limit: payload_json = json.loads(payload) @@ -2039,9 +2038,9 @@ def sql_json(self): with utils.timeout( seconds=SQLLAB_TIMEOUT, error_message=( - "The query exceeded the {SQLLAB_TIMEOUT} seconds " - "timeout. You may want to run your query as a " - "`CREATE TABLE AS` to prevent timeouts." + "The query exceeded the {SQLLAB_TIMEOUT} seconds " + "timeout. You may want to run your query as a " + "`CREATE TABLE AS` to prevent timeouts." ).format(**locals())): data = sql_lab.get_sql_results(query_id, return_results=True) except Exception as e: @@ -2056,8 +2055,8 @@ def csv(self, client_id): """Download the query results as csv.""" query = ( db.session.query(models.Query) - .filter_by(client_id=client_id) - .one() + .filter_by(client_id=client_id) + .one() ) rejected_tables = self.rejected_datasources( @@ -2069,7 +2068,7 @@ def csv(self, client_id): if results_backend and query.results_key: blob = results_backend.get(query.results_key) if blob: - json_payload = zlib.decompress(blob) + json_payload = utils.zlib_uncompress_to_string(blob) obj = json.loads(json_payload) columns = [c['name'] for c in obj['columns']] df = pd.DataFrame.from_records(obj['data'], columns=columns) @@ -2093,8 +2092,8 @@ def fetch_datasource_metadata(self): datasource_class = ConnectorRegistry.sources[datasource_type] datasource = ( db.session.query(datasource_class) - .filter_by(id=int(datasource_id)) - .first() + .filter_by(id=int(datasource_id)) + .first() ) # Check if datasource exists @@ -2122,11 
def test_zlib_compression(self):
    """Round-trip both text and bytes payloads through the zlib helpers.

    Whatever the input type, ``zlib_uncompress_to_string`` is expected to
    hand back the native string form of the payload.
    """
    json_str = """{"test": 1}"""
    blob = zlib_compress(json_str)
    got_str = zlib_uncompress_to_string(blob)
    # assertEqual rather than the deprecated assertEquals alias, which no
    # longer exists on Python 3.12+.
    self.assertEqual(json_str, got_str)

    # Same payload supplied as bytes must decode to the same string.
    byte_str = b"""{"test": 1}"""
    blob = zlib_compress(byte_str)
    got_str = zlib_uncompress_to_string(blob)
    self.assertEqual(json_str, got_str)