Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,12 @@

package org.apache.doris.analysis;

import org.apache.doris.catalog.Env;
import org.apache.doris.common.ErrorCode;
import org.apache.doris.common.ErrorReport;
import org.apache.doris.common.UserException;
import org.apache.doris.mysql.privilege.PrivPredicate;
import org.apache.doris.qe.ConnectContext;

import lombok.Getter;

Expand Down Expand Up @@ -69,5 +74,11 @@ public List<String> getPartitionNames() {

@Override
public void analyze(Analyzer analyzer) throws UserException {
if (!Env.getCurrentEnv().getAccessManager()
.checkTblPriv(ConnectContext.get(), getDb(), getTbl(), PrivPredicate.LOAD)) {
ErrorReport.reportAnalysisException(ErrorCode.ERR_TABLEACCESS_DENIED_ERROR, "LOAD",
ConnectContext.get().getQualifiedUser(), ConnectContext.get().getRemoteIP(),
getDb() + ": " + getTbl());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,9 @@ public boolean checkTblPriv(ConnectContext ctx, TableName tableName, PrivPredica

public boolean checkTblPriv(ConnectContext ctx, String qualifiedCtl,
String qualifiedDb, String tbl, PrivPredicate wanted) {
if (ctx.isSkipAuth()) {
return true;
}
return checkTblPriv(ctx.getCurrentUserIdentity(), qualifiedCtl, qualifiedDb, tbl, wanted);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,11 @@
import org.apache.doris.catalog.OlapTable;
import org.apache.doris.common.DdlException;
import org.apache.doris.common.ErrorCode;
import org.apache.doris.common.ErrorReport;
import org.apache.doris.common.UserException;
import org.apache.doris.common.util.ProfileManager.ProfileType;
import org.apache.doris.datasource.InternalCatalog;
import org.apache.doris.mysql.privilege.PrivPredicate;
import org.apache.doris.nereids.NereidsPlanner;
import org.apache.doris.nereids.analyzer.UnboundOlapTableSink;
import org.apache.doris.nereids.exceptions.AnalysisException;
Expand Down Expand Up @@ -134,6 +136,16 @@ public void run(ConnectContext ctx, StmtExecutor executor) throws Exception {
Preconditions.checkArgument(plan.isPresent(), "insert into command must contain OlapTableSinkNode");
PhysicalOlapTableSink<?> physicalOlapTableSink = ((PhysicalOlapTableSink<?>) plan.get());

OlapTable targetTable = physicalOlapTableSink.getTargetTable();
// check auth
if (!Env.getCurrentEnv().getAccessManager()
.checkTblPriv(ConnectContext.get(), targetTable.getQualifiedDbName(), targetTable.getName(),
PrivPredicate.LOAD)) {
ErrorReport.reportAnalysisException(ErrorCode.ERR_TABLEACCESS_DENIED_ERROR, "LOAD",
ConnectContext.get().getQualifiedUser(), ConnectContext.get().getRemoteIP(),
targetTable.getQualifiedDbName() + ": " + targetTable.getName());
}

if (isOverwrite) {
dealOverwrite(ctx, executor, physicalOlapTableSink);
return;
Expand Down Expand Up @@ -189,24 +201,29 @@ public void run(ConnectContext ctx, StmtExecutor executor) throws Exception {
* @param ctx ctx
* @param executor executor
* @param physicalOlapTableSink physicalOlapTableSink
*
* @throws Exception Exception
*/
public void dealOverwrite(ConnectContext ctx, StmtExecutor executor,
PhysicalOlapTableSink<?> physicalOlapTableSink) throws Exception {
OlapTable targetTable = physicalOlapTableSink.getTargetTable();
TableName tableName = new TableName(InternalCatalog.INTERNAL_CATALOG_NAME, targetTable.getQualifiedDbName(),
targetTable.getName());
List<String> partitionNames = ((UnboundOlapTableSink<?>) logicalQuery).getPartitions();
if (CollectionUtils.isEmpty(partitionNames)) {
partitionNames = Lists.newArrayList(targetTable.getPartitionNames());
}
List<String> tempPartitionNames = addTempPartition(ctx, tableName, partitionNames);
boolean insertRes = insertInto(ctx, executor, tempPartitionNames, tableName);
if (!insertRes) {
return;
ConnectContext.get().setSkipAuth(true);
try {
List<String> partitionNames = ((UnboundOlapTableSink<?>) logicalQuery).getPartitions();
if (CollectionUtils.isEmpty(partitionNames)) {
partitionNames = Lists.newArrayList(targetTable.getPartitionNames());
}
List<String> tempPartitionNames = addTempPartition(ctx, tableName, partitionNames);
boolean insertRes = insertInto(ctx, executor, tempPartitionNames, tableName);
if (!insertRes) {
return;
}
replacePartition(ctx, tableName, partitionNames, tempPartitionNames);
} finally {
ConnectContext.get().setSkipAuth(false);
}
replacePartition(ctx, tableName, partitionNames, tempPartitionNames);

}

/**
Expand Down
14 changes: 14 additions & 0 deletions fe/fe-core/src/main/java/org/apache/doris/qe/ConnectContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,12 @@ public enum ConnectType {

private TResultSinkType resultSinkType = TResultSinkType.MYSQL_PROTOCAL;

//internal call like `insert overwrite` need skipAuth
// For example, `insert overwrite` only requires load permission,
// but the internal implementation will call the logic of `AlterTable`.
// In this case, `skipAuth` needs to be set to `true` to skip the permission check of `AlterTable`
private boolean skipAuth = false;

public void setUserQueryTimeout(int queryTimeout) {
if (queryTimeout > 0) {
sessionVariable.setQueryTimeoutS(queryTimeout);
Expand Down Expand Up @@ -903,5 +909,13 @@ public void setInsertGroupCommit(long tableId, Backend backend) {
public Backend getInsertGroupCommit(long tableId) {
return insertGroupCommitTableToBeMap.get(tableId);
}

public boolean isSkipAuth() {
return skipAuth;
}

public void setSkipAuth(boolean skipAuth) {
this.skipAuth = skipAuth;
}
}

20 changes: 13 additions & 7 deletions fe/fe-core/src/main/java/org/apache/doris/qe/StmtExecutor.java
Original file line number Diff line number Diff line change
Expand Up @@ -994,6 +994,7 @@ public void analyze(TQueryOptions tQueryOptions) throws UserException, Interrupt
queryStmt.getTables(analyzer, false, tableMap, parentViewNameSet);
} else if (parsedStmt instanceof InsertOverwriteTableStmt) {
InsertOverwriteTableStmt parsedStmt = (InsertOverwriteTableStmt) this.parsedStmt;
parsedStmt.analyze(analyzer);
queryStmt = parsedStmt.getQueryStmt();
queryStmt.getTables(analyzer, false, tableMap, parentViewNameSet);
} else if (parsedStmt instanceof CreateTableAsSelectStmt) {
Expand Down Expand Up @@ -2390,13 +2391,18 @@ private void handleCtasRollback(TableName table) {
}

private void handleIotStmt() {
InsertOverwriteTableStmt iotStmt = (InsertOverwriteTableStmt) this.parsedStmt;
if (iotStmt.getPartitionNames().size() == 0) {
// insert overwrite table
handleOverwriteTable(iotStmt);
} else {
// insert overwrite table with partition
handleOverwritePartition(iotStmt);
ConnectContext.get().setSkipAuth(true);
try {
InsertOverwriteTableStmt iotStmt = (InsertOverwriteTableStmt) this.parsedStmt;
if (iotStmt.getPartitionNames().size() == 0) {
// insert overwrite table
handleOverwriteTable(iotStmt);
} else {
// insert overwrite table with partition
handleOverwritePartition(iotStmt);
}
} finally {
ConnectContext.get().setSkipAuth(false);
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

suite('nereids_insert_auth') {
sql 'set enable_nereids_planner=true'
sql 'set enable_fallback_to_original_planner=false'
sql 'set enable_nereids_dml=true'
sql 'set enable_strict_consistency_dml=true'

def db = 'nereids_insert_auth_db'
sql "drop database if exists ${db}"
sql "create database ${db}"
sql "use ${db}"

def t1 = 't1'

sql "drop table if exists ${t1}"

sql """
create table ${t1} (
id int,
c1 bigint
)
distributed by hash(id) buckets 2
properties(
'replication_num'='1'
);
"""

String user = "nereids_insert_auth_user";
String pwd = '123456';
def tokens = context.config.jdbcUrl.split('/')
def url = tokens[0] + "//" + tokens[2] + "/" + "information_schema" + "?"
try_sql("DROP USER ${user}")
sql """CREATE USER '${user}' IDENTIFIED BY '${pwd}'"""

connect(user=user, password="${pwd}", url=url) {
try {
sql """ insert into ${db}.${t1} values (1, 1) """
fail()
} catch (Exception e) {
log.info(e.getMessage())
}
}

sql """GRANT LOAD_PRIV ON ${db}.${t1} TO ${user}"""

connect(user=user, password="${pwd}", url=url) {
try {
sql """ insert into ${db}.${t1} values (1, 1) """
} catch (Exception e) {
log.info(e.getMessage())
fail()
}
}

connect(user=user, password="${pwd}", url=url) {
try {
sql """ insert overwrite table ${db}.${t1} values (2, 2) """
} catch (Exception e) {
log.info(e.getMessage())
fail()
}
}
}