Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions caravel/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,11 @@ def sql_url(self):
def sql_link(self):
return '<a href="{}">SQL</a>'.format(self.sql_url)

@property
def perm(self):
return (
"[{obj.database_name}].(id:{obj.id})").format(obj=self)


class SqlaTable(Model, Queryable, AuditMixinNullable):

Expand Down
6 changes: 5 additions & 1 deletion caravel/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def init(caravel):

perms = db.session.query(ab_models.PermissionView).all()
for perm in perms:
if perm.permission.name == 'datasource_access':
if perm.permission.name in ('datasource_access', 'database_access'):
continue
if perm.view_menu and perm.view_menu.name not in (
'UserDBModelView', 'RoleModelView', 'ResetPasswordView',
Expand All @@ -226,6 +226,7 @@ def init(caravel):
'can_edit',
'can_save',
'datasource_access',
'database_access',
'muldelete',
)):
sm.add_permission_role(gamma, perm)
Expand All @@ -239,6 +240,9 @@ def init(caravel):
for table_perm in table_perms:
merge_perm(sm, 'datasource_access', table_perm)

db_perms = [db.perm for db in session.query(models.Database).all()]
for db_perm in db_perms:
merge_perm(sm, 'database_access', db_perm)
init_metrics_perm(caravel)


Expand Down
65 changes: 44 additions & 21 deletions caravel/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,7 @@ def pre_add(self, db):
db.password = conn.password
conn.password = "X" * 10 if conn.password else None
db.sqlalchemy_uri = str(conn) # hides the password
utils.merge_perm(sm, 'database_access', db.perm)

def pre_update(self, db):
self.pre_add(db)
Expand Down Expand Up @@ -1176,15 +1177,17 @@ def dashboard(**kwargs): # noqa
@expose("/sql/<database_id>/")
@log_this
def sql(self, database_id):
if (
not self.can_access(
'all_datasource_access', 'all_datasource_access')):
flash(
"This view requires the `all_datasource_access` "
"permission", "danger")
return redirect("/tablemodelview/list/")
mydb = db.session.query(
models.Database).filter_by(id=database_id).first()

if not (self.can_access(
'all_datasource_access', 'all_datasource_access') or
self.can_access('database_access', mydb.perm)):
flash(
"This view requires the specific database or "
"`all_datasource_access` permission", "danger"
)
return redirect("/tablemodelview/list/")
engine = mydb.get_sqla_engine()
tables = engine.table_names()

Expand Down Expand Up @@ -1221,6 +1224,18 @@ def select_star(self, database_id, table_name):
mydb = db.session.query(
models.Database).filter_by(id=database_id).first()
t = mydb.get_table(table_name)

# Prevent exposing column fields to users that cannot access DB.
if not (self.can_access(
'all_datasource_access', 'all_datasource_access') or
self.can_access('database_access', mydb.perm) or
self.can_access('datasource_access', t.perm)):
flash(
"This view requires the specific database, table or "
"`all_datasource_access` permission", "danger"
)
return redirect("/tablemodelview/list/")

fields = ", ".join(
[c.name for c in t.columns] or "*")
s = "SELECT\n{}\nFROM {}".format(fields, table_name)
Expand All @@ -1242,22 +1257,26 @@ def runsql(self):
database_id = data.get('database_id')
mydb = session.query(models.Database).filter_by(id=database_id).first()

if (
not self.can_access(
'all_datasource_access', 'all_datasource_access')):
if not (self.can_access(
'all_datasource_access', 'all_datasource_access') or
self.can_access('database_access', mydb.perm)):
raise utils.CaravelSecurityException(_(
"SQL Lab requires the `all_datasource_access` permission"))
"SQL Lab requires the `all_datasource_access` or "
"specific db permission"))

content = ""
if mydb:
eng = mydb.get_sqla_engine()
if limit:
sql = sql.strip().strip(';')
qry = (
select('*')
.select_from(TextAsFrom(text(sql), ['*']).alias('inner_qry'))
.select_from(TextAsFrom(text(sql), ['*'])
.alias('inner_qry'))
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't this line up with the rest of the calls in the method chaining?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pycharm puts it there. It uses pep8 by default. .alias belongs to the TextAsFrom(text(sql), ['*']), not to the result of the select_from.

.limit(limit)
)
sql = str(qry.compile(eng, compile_kwargs={"literal_binds": True}))
sql = '{}'.format(qry.compile(
eng, compile_kwargs={"literal_binds": True}))
try:
df = pd.read_sql_query(sql=sql, con=eng)
content = df.to_html(
Expand Down Expand Up @@ -1289,11 +1308,12 @@ def sql_json(self):
database_id = request.form.get('database_id')
mydb = session.query(models.Database).filter_by(id=database_id).first()

if (
not self.can_access(
'all_datasource_access', 'all_datasource_access')):
if not (self.can_access(
'all_datasource_access', 'all_datasource_access') or
self.can_access('database_access', mydb.perm)):
raise utils.CaravelSecurityException(_(
"This view requires the `all_datasource_access` permission"))
"SQL Lab requires the `all_datasource_access` or "
"specific DB permission"))

error_msg = ""
if not mydb:
Expand All @@ -1304,10 +1324,12 @@ def sql_json(self):
sql = sql.strip().strip(';')
qry = (
select('*')
.select_from(TextAsFrom(text(sql), ['*']).alias('inner_qry'))
.select_from(TextAsFrom(text(sql), ['*'])
.alias('inner_qry'))
.limit(limit)
)
sql = str(qry.compile(eng, compile_kwargs={"literal_binds": True}))
sql = '{}'.format(qry.compile(
eng, compile_kwargs={"literal_binds": True}))
try:
df = pd.read_sql_query(sql=sql, con=eng)
df = df.fillna(0) # TODO make sure NULL
Expand All @@ -1328,7 +1350,8 @@ def sql_json(self):
'columns': [c for c in df.columns],
'data': df.to_dict(orient='records'),
}
return json.dumps(data, default=utils.json_int_dttm_ser, allow_nan=False)
return json.dumps(
data, default=utils.json_int_dttm_ser, allow_nan=False)

@has_access
@expose("/refresh_datasources/")
Expand All @@ -1342,7 +1365,7 @@ def refresh_datasources(self):
except Exception as e:
flash(
"Error while processing cluster '{}'\n{}".format(
cluster_name, str(e)),
cluster_name, utils.error_msg_from_exception(e)),
"danger")
logging.exception(e)
return redirect('/druidclustermodelview/list/')
Expand Down
45 changes: 39 additions & 6 deletions tests/core_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from flask_appbuilder.security.sqla import models as ab_models

import caravel
from caravel import app, db, models, utils, appbuilder
from caravel import app, db, models, utils, appbuilder, sm
from caravel.models import DruidCluster, DruidDatasource

os.environ['CARAVEL_CONFIG'] = 'tests.caravel_test_config'
Expand Down Expand Up @@ -247,8 +247,8 @@ def test_gamma(self):
resp = self.client.get('/dashboardmodelview/list/')
assert "List Dashboard" in resp.data.decode('utf-8')

def run_sql(self, sql):
self.login(username='admin')
def run_sql(self, sql, user_name):
self.login(username=user_name)
dbid = (
db.session.query(models.Database)
.filter_by(database_name="main")
Expand All @@ -258,13 +258,47 @@ def run_sql(self, sql):
'/caravel/sql_json/',
data=dict(database_id=dbid, sql=sql),
)
self.logout()
return json.loads(resp.data.decode('utf-8'))

def test_sql_json_no_access(self):
self.assertRaises(
utils.CaravelSecurityException,
self.run_sql, "SELECT * FROM ab_user", 'gamma')

def test_sql_json_has_access(self):
main_db = (
db.session.query(models.Database).filter_by(database_name="main")
.first()
)
utils.merge_perm(sm, 'database_access', main_db.perm)
db.session.commit()
main_db_permission_view = (
db.session.query(ab_models.PermissionView)
.join(ab_models.ViewMenu)
.filter(ab_models.ViewMenu.name == '[main].(id:1)')
.first()
)
astronaut = sm.add_role("Astronaut")
sm.add_permission_role(astronaut, main_db_permission_view)
# Astronaut role is Gamme + main db permissions
for gamma_perm in sm.find_role('Gamma').permissions:
sm.add_permission_role(astronaut, gamma_perm)

gagarin = appbuilder.sm.find_user('gagarin')
if not gagarin:
appbuilder.sm.add_user(
'gagarin', 'Iurii', 'Gagarin', 'gagarin@cosmos.ussr',
appbuilder.sm.find_role('Astronaut'),
password='general')
data = self.run_sql('SELECT * FROM ab_user', 'gagarin')
assert len(data['data']) > 0

def test_sql_json(self):
data = self.run_sql("SELECT * FROM ab_user")
data = self.run_sql("SELECT * FROM ab_user", 'admin')
assert len(data['data']) > 0

data = self.run_sql("SELECT * FROM unexistant_table")
data = self.run_sql("SELECT * FROM unexistant_table", 'admin')
assert len(data['error']) > 0

def test_public_user_dashboard_access(self):
Expand Down Expand Up @@ -301,7 +335,6 @@ def test_public_user_dashboard_access(self):
data = resp.data.decode('utf-8')
assert "/caravel/dashboard/world_health/" not in data


def test_only_owners_can_save(self):
dash = (
db.session
Expand Down