Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion fe/fe-core/src/main/java/org/apache/doris/policy/Policy.java
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,8 @@ public static Policy fromCreateStmt(CreatePolicyStmt stmt) throws AnalysisExcept
}
return new RowPolicy(policyId, stmt.getPolicyName(), stmt.getTableName().getCtl(),
stmt.getTableName().getDb(), stmt.getTableName().getTbl(), userIdent, stmt.getRoleName(),
stmt.getOrigStmt().originStmt, stmt.getFilterType(), stmt.getWherePredicate());
stmt.getOrigStmt().originStmt, stmt.getOrigStmt().idx, stmt.getFilterType(),
stmt.getWherePredicate());
default:
throw new AnalysisException("Unknown policy type: " + stmt.getType());
}
Expand Down
20 changes: 16 additions & 4 deletions fe/fe-core/src/main/java/org/apache/doris/policy/RowPolicy.java
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,8 @@ public class RowPolicy extends Policy implements RowFilterPolicy {
**/
@SerializedName(value = "originStmt")
private String originStmt;
@SerializedName(value = "stmtIdx")
private int stmtIdx;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what will stmtIdx be when do deserialization if no stmtIdx in json

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will be default 0,we only support stmtIdx is 0


private Expr wherePredicate = null;

Expand All @@ -123,7 +125,7 @@ public RowPolicy() {
* @param wherePredicate where predicate
*/
public RowPolicy(long policyId, final String policyName, long dbId, UserIdentity user, String roleName,
String originStmt,
String originStmt, int stmtIdx,
final long tableId, final FilterType filterType, final Expr wherePredicate) {
super(policyId, PolicyTypeEnum.ROW, policyName);
this.user = user;
Expand All @@ -132,12 +134,13 @@ public RowPolicy(long policyId, final String policyName, long dbId, UserIdentity
this.tableId = tableId;
this.filterType = filterType;
this.originStmt = originStmt;
this.stmtIdx = stmtIdx;
this.wherePredicate = wherePredicate;
}

public RowPolicy(long policyId, final String policyName, String ctlName, String dbName, String tableName,
UserIdentity user, String roleName,
String originStmt, final FilterType filterType, final Expr wherePredicate) {
String originStmt, int stmtIdx, final FilterType filterType, final Expr wherePredicate) {
super(policyId, PolicyTypeEnum.ROW, policyName);
this.user = user;
this.roleName = roleName;
Expand All @@ -146,6 +149,7 @@ public RowPolicy(long policyId, final String policyName, String ctlName, String
this.tableName = tableName;
this.filterType = filterType;
this.originStmt = originStmt;
this.stmtIdx = stmtIdx;
this.wherePredicate = wherePredicate;
}

Expand All @@ -166,16 +170,20 @@ public void gsonPostProcess() throws IOException {
try {
SqlScanner input = new SqlScanner(new StringReader(originStmt), 0L);
SqlParser parser = new SqlParser(input);
CreatePolicyStmt stmt = (CreatePolicyStmt) SqlParserUtils.getFirstStmt(parser);
CreatePolicyStmt stmt = (CreatePolicyStmt) SqlParserUtils.getStmt(parser, stmtIdx);
wherePredicate = stmt.getWherePredicate();
} catch (Exception e) {
throw new IOException("table policy parse originStmt error", e);
String errorMsg = String.format("table policy parse originStmt error, originStmt: %s, stmtIdx: %s.",
originStmt, stmtIdx);
// Only print logs to avoid cluster failure to start
LOG.warn(errorMsg, e);
}
}

@Override
public RowPolicy clone() {
return new RowPolicy(this.id, this.policyName, this.dbId, this.user, this.roleName, this.originStmt,
this.stmtIdx,
this.tableId,
this.filterType, this.wherePredicate);
}
Expand Down Expand Up @@ -218,6 +226,10 @@ public boolean isInvalid() {
public Expression getFilterExpression() throws AnalysisException {
NereidsParser nereidsParser = new NereidsParser();
String sql = getOriginStmt();
if (getStmtIdx() != 0) {
// Under normal circumstances, the index will only be equal to 0
throw new AnalysisException("Invalid row policy [" + getPolicyIdent() + "], " + sql);
}
CreatePolicyCommand command = (CreatePolicyCommand) nereidsParser.parseSingle(sql);
Optional<Expression> wherePredicate = command.getWherePredicate();
if (!wherePredicate.isPresent()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,7 @@ public void executeQuery(MysqlCommand mysqlCommand, String originStmt) throws Ex
}

StatementBase parsedStmt = stmts.get(i);
parsedStmt.setOrigStmt(new OriginStatement(convertedStmt, i));
parsedStmt.setOrigStmt(new OriginStatement(auditStmt, usingOrigSingleStmt ? 0 : i));
parsedStmt.setUserInfo(ctx.getCurrentUserIdentity());
executor = new StmtExecutor(ctx, parsedStmt);
executor.getProfile().getSummaryProfile().setParseSqlStartTime(parseSqlStartTime);
Expand Down
21 changes: 20 additions & 1 deletion fe/fe-core/src/test/java/org/apache/doris/policy/PolicyTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.apache.doris.common.DdlException;
import org.apache.doris.common.ExceptionChecker;
import org.apache.doris.common.FeConstants;
import org.apache.doris.persist.gson.GsonUtils;
import org.apache.doris.utframe.TestWithFeService;

import com.google.common.collect.Lists;
Expand Down Expand Up @@ -341,7 +342,7 @@ public void testReadWrite() throws IOException, AnalysisException {
FilterType filterType = FilterType.PERMISSIVE;
Expr wherePredicate = null;

Policy rowPolicy = new RowPolicy(10000, policyName, dbId, user, null, originStmt, tableId, filterType,
Policy rowPolicy = new RowPolicy(10000, policyName, dbId, user, null, originStmt, 0, tableId, filterType,
wherePredicate);

ByteArrayOutputStream emptyOutputStream = new ByteArrayOutputStream();
Expand All @@ -364,4 +365,22 @@ public void testReadWrite() throws IOException, AnalysisException {
Assertions.assertEquals(tableId, newRowPolicy.getTableId());
Assertions.assertEquals(filterType, newRowPolicy.getFilterType());
}

@Test
public void testCompatibility() {
String s1 = "{\n"
+ " \"clazz\": \"RowPolicy\",\n"
+ " \"roleName\": \"role1\",\n"
+ " \"dbId\": 2,\n"
+ " \"tableId\": 2,\n"
+ " \"filterType\": \"PERMISSIVE\",\n"
+ " \"originStmt\": \"CREATE ROW POLICY test_row_policy ON test.table1 AS PERMISSIVE TO test_policy USING (k1 \\u003d 1)\",\n"
+ " \"id\": 1,\n"
+ " \"type\": \"ROW\",\n"
+ " \"policyName\": \"cc\",\n"
+ " \"version\": 0\n"
+ "}";
RowPolicy rowPolicy = GsonUtils.GSON.fromJson(s1, RowPolicy.class);
Assertions.assertEquals(rowPolicy.getStmtIdx(), 0);
}
}