diff --git a/be/src/exec/mysql_scan_node.cpp b/be/src/exec/mysql_scan_node.cpp index 2da3507177512e..1fecb87822f6cf 100644 --- a/be/src/exec/mysql_scan_node.cpp +++ b/be/src/exec/mysql_scan_node.cpp @@ -116,7 +116,8 @@ Status MysqlScanNode::open(RuntimeState* state) { RETURN_IF_CANCELLED(state); SCOPED_TIMER(_runtime_profile->total_time_counter()); RETURN_IF_ERROR(_mysql_scanner->open()); - RETURN_IF_ERROR(_mysql_scanner->query(_table_name, _columns, _filters)); + RETURN_IF_ERROR(_mysql_scanner->query(_table_name, _columns, _filters, _limit)); + // check materialize slot num int materialize_num = 0; @@ -161,11 +162,6 @@ Status MysqlScanNode::get_next(RuntimeState* state, RowBatch* row_batch, bool* e SCOPED_TIMER(_runtime_profile->total_time_counter()); SCOPED_TIMER(materialize_tuple_timer()); - if (reached_limit()) { - *eos = true; - return Status::OK(); - } - // create new tuple buffer for row_batch int tuple_buffer_size = row_batch->capacity() * _tuple_desc->byte_size(); void* tuple_buffer = _tuple_pool->allocate(tuple_buffer_size); @@ -181,11 +177,10 @@ Status MysqlScanNode::get_next(RuntimeState* state, RowBatch* row_batch, bool* e while (true) { RETURN_IF_CANCELLED(state); - if (reached_limit() || row_batch->is_full()) { + if (row_batch->is_full()) { // hang on to last allocated chunk in pool, we'll keep writing into it in the // next get_next() call row_batch->tuple_data_pool()->acquire_data(_tuple_pool.get(), !reached_limit()); - *eos = reached_limit(); return Status::OK(); } diff --git a/be/src/exec/mysql_scanner.cpp b/be/src/exec/mysql_scanner.cpp index 274650b7111f93..7cf5ba9c024b42 100644 --- a/be/src/exec/mysql_scanner.cpp +++ b/be/src/exec/mysql_scanner.cpp @@ -111,7 +111,7 @@ Status MysqlScanner::query(const std::string& query) { } Status MysqlScanner::query(const std::string& table, const std::vector& fields, - const std::vector& filters) { + const std::vector& filters, const uint64_t limit) { if (!_is_open) { return Status::InternalError("Query before open."); } @@ -140,6 +140,10 @@ Status MysqlScanner::query(const std::string& table, const std::vector& fields, - const std::vector& filters); + const std::vector& filters, const uint64_t limit); Status get_next_row(char** *buf, unsigned long** lengths, bool* eos); int field_num() const { diff --git a/be/src/exec/odbc_scan_node.cpp b/be/src/exec/odbc_scan_node.cpp index 69949e4b41256c..14ea0eb1112e57 100644 --- a/be/src/exec/odbc_scan_node.cpp +++ b/be/src/exec/odbc_scan_node.cpp @@ -34,9 +34,9 @@ OdbcScanNode::OdbcScanNode(ObjectPool* pool, const TPlanNode& tnode, : ScanNode(pool, tnode, descs), _is_init(false), _table_name(tnode.odbc_scan_node.table_name), + _connect_string(std::move(tnode.odbc_scan_node.connect_string)), + _query_string(std::move(tnode.odbc_scan_node.query_string)), _tuple_id(tnode.odbc_scan_node.tuple_id), - _columns(tnode.odbc_scan_node.columns), - _filters(tnode.odbc_scan_node.filters), _tuple_desc(nullptr) { } @@ -63,21 +63,9 @@ Status OdbcScanNode::prepare(RuntimeState* state) { } _slot_num = _tuple_desc->slots().size(); - // get odbc table info - const ODBCTableDescriptor* odbc_table = - static_cast(_tuple_desc->table_desc()); - if (NULL == odbc_table) { - return Status::InternalError("odbc table pointer is NULL."); - } - - _odbc_param.host = odbc_table->host(); - _odbc_param.port = odbc_table->port(); - _odbc_param.user = odbc_table->user(); - _odbc_param.passwd = odbc_table->passwd(); - _odbc_param.db = odbc_table->db(); - _odbc_param.drivier = odbc_table->driver(); - _odbc_param.type = odbc_table->type(); + _odbc_param.connect_string = std::move(_connect_string); + _odbc_param.query_string = std::move(_query_string); _odbc_param.tuple_desc = _tuple_desc; _odbc_scanner.reset(new (std::nothrow)ODBCScanner(_odbc_param)); @@ -119,7 +107,7 @@ Status OdbcScanNode::open(RuntimeState* state) { RETURN_IF_CANCELLED(state); SCOPED_TIMER(_runtime_profile->total_time_counter()); RETURN_IF_ERROR(_odbc_scanner->open()); - RETURN_IF_ERROR(_odbc_scanner->query(_table_name, _columns, _filters)); + RETURN_IF_ERROR(_odbc_scanner->query()); // check materialize slot num return Status::OK(); @@ -153,11 +141,6 @@ Status OdbcScanNode::get_next(RuntimeState* state, RowBatch* row_batch, bool* eo SCOPED_TIMER(_runtime_profile->total_time_counter()); SCOPED_TIMER(materialize_tuple_timer()); - if (reached_limit()) { - *eos = true; - return Status::OK(); - } - // create new tuple buffer for row_batch int tuple_buffer_size = row_batch->capacity() * _tuple_desc->byte_size(); void* tuple_buffer = _tuple_pool->allocate(tuple_buffer_size); @@ -173,11 +156,10 @@ Status OdbcScanNode::get_next(RuntimeState* state, RowBatch* row_batch, bool* eo while (true) { RETURN_IF_CANCELLED(state); - if (reached_limit() || row_batch->is_full()) { + if (row_batch->is_full()) { // hang on to last allocated chunk in pool, we'll keep writing into it in the // next get_next() call row_batch->tuple_data_pool()->acquire_data(_tuple_pool.get(), !reached_limit()); - *eos = reached_limit(); return Status::OK(); } @@ -238,8 +220,6 @@ Status OdbcScanNode::get_next(RuntimeState* state, RowBatch* row_batch, bool* eo _tuple = reinterpret_cast(new_tuple); } } - - return Status::OK(); } Status OdbcScanNode::close(RuntimeState* state) { diff --git a/be/src/exec/odbc_scan_node.h b/be/src/exec/odbc_scan_node.h index 48e28aab87677a..37b9bece8aa17f 100644 --- a/be/src/exec/odbc_scan_node.h +++ b/be/src/exec/odbc_scan_node.h @@ -68,13 +68,12 @@ class OdbcScanNode : public ScanNode { // Name of Odbc table std::string _table_name; + std::string _connect_string; + + std::string _query_string; // Tuple id resolved in prepare() to set _tuple_desc; TupleId _tuple_id; - // select columns - std::vector _columns; - // where clause - std::vector _filters; // Descriptor of tuples read from ODBC table. const TupleDescriptor* _tuple_desc; diff --git a/be/src/exec/odbc_scanner.cpp b/be/src/exec/odbc_scanner.cpp index 94f4876e1c63d0..fbaccc62d44395 100644 --- a/be/src/exec/odbc_scanner.cpp +++ b/be/src/exec/odbc_scanner.cpp @@ -47,8 +47,8 @@ static std::u16string utf8_to_wstring(const std::string& str) { namespace doris { ODBCScanner::ODBCScanner(const ODBCScannerParam& param) - : _connect_string(build_connect_string(param)), - _type(param.type), + : _connect_string(param.connect_string), + _sql_str(param.query_string), _tuple_desc(param.tuple_desc), _is_open(false), _field_num(0), @@ -97,7 +97,7 @@ Status ODBCScanner::open() { return Status::OK(); } -Status ODBCScanner::query(const std::string& query) { +Status ODBCScanner::query() { if (!_is_open) { return Status::InternalError( "Query before open."); } @@ -106,13 +106,13 @@ Status ODBCScanner::query(const std::string& query) { ODBC_DISPOSE(_dbc, SQL_HANDLE_DBC, SQLAllocHandle(SQL_HANDLE_STMT, _dbc, &_stmt), "alloc statement"); // Translate utf8 string to utf16 to use unicode codeing - auto wquery = utf8_to_wstring(query); + auto wquery = utf8_to_wstring(_sql_str); ODBC_DISPOSE(_stmt, SQL_HANDLE_STMT, SQLExecDirectW(_stmt, (SQLWCHAR*)(wquery.c_str()), SQL_NTS), "exec direct"); // How many columns are there */ ODBC_DISPOSE(_stmt, SQL_HANDLE_STMT, SQLNumResultCols(_stmt, &_field_num), "count num colomn"); - LOG(INFO) << "execute success:" << query << " column count:" << _field_num; + LOG(INFO) << "execute success:" << _sql_str << " column count:" << _field_num; // check materialize num equal _field_num int materialize_num = 0; @@ -145,39 +145,6 @@ Status ODBCScanner::query(const std::string& query) { return Status::OK(); } -Status ODBCScanner::query(const std::string& table, const std::vector& fields, - const std::vector& filters) { - if (!_is_open) { - return Status::InternalError("Query before open."); - } - - _sql_str = "SELECT "; - - for (int i = 0; i < fields.size(); ++i) { - if (0 != i) { - _sql_str += ","; - } - - _sql_str += fields[i]; - } - - _sql_str += " FROM " + table; - - if (!filters.empty()) { - _sql_str += " WHERE "; - - for (int i = 0; i < filters.size(); ++i) { - if (0 != i) { - _sql_str += " AND"; - } - - _sql_str += " (" + filters[i] + ") "; - } - } - - return query(_sql_str); -} - Status ODBCScanner::get_next_row(bool* eos) { if (!_is_open) { return Status::InternalError("GetNextRow before open."); @@ -240,23 +207,4 @@ std::string ODBCScanner::handle_diagnostic_record(SQLHANDLE hHandle, return diagnostic_msg; } -std::string ODBCScanner::build_connect_string(const ODBCScannerParam& param) { - // different database have different connection string - // oracle connect string - if (param.type == TOdbcTableType::ORACLE) { - boost::format connect_string("Driver=%s;Dbq=//%s:%s/%s;DataBase=%s;Uid=%s;Pwd=%s;charset=%s"); - connect_string % param.drivier % param.host % param.port % param.db % param.db % param.user % param.passwd % - param.charest; - - return connect_string.str(); - } else if (param.type == TOdbcTableType::MYSQL) { - boost::format connect_string("Driver=%s;Server=%s;Port=%s;DataBase=%s;Uid=%s;Pwd=%s;charset=%s"); - connect_string % param.drivier % param.host % param.port % param.db % param.user % param.passwd % - param.charest; - return connect_string.str(); - } - - return ""; -} - } diff --git a/be/src/exec/odbc_scanner.h b/be/src/exec/odbc_scanner.h index 50f7d84b337950..14c84a9df0b6d8 100644 --- a/be/src/exec/odbc_scanner.h +++ b/be/src/exec/odbc_scanner.h @@ -32,15 +32,9 @@ namespace doris { struct ODBCScannerParam { - std::string host; - std::string port; - std::string user; - std::string passwd; - std::string db; - std::string drivier; - std::string charest = "utf8"; - - TOdbcTableType::type type; + std::string connect_string; + std::string query_string; + const TupleDescriptor* tuple_desc; }; @@ -67,11 +61,8 @@ class ODBCScanner { Status open(); - Status query(const std::string& query); - - // query for DORIS - Status query(const std::string& table, const std::vector& fields, - const std::vector& filters); + // query for ODBC table + Status query(); Status get_next_row(bool* eos); @@ -80,8 +71,6 @@ class ODBCScanner { } private: - static std::string build_connect_string(const ODBCScannerParam& param); - static Status error_status(const std::string& prefix, const std::string& error_msg); static std::string handle_diagnostic_record (SQLHANDLE hHandle, @@ -90,7 +79,6 @@ class ODBCScanner { std::string _connect_string; std::string _sql_str; - TOdbcTableType::type _type; const TupleDescriptor* _tuple_desc; bool _is_open; diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/OdbcTable.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/OdbcTable.java index a79fb203212c3b..abed36cd762973 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/OdbcTable.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/OdbcTable.java @@ -59,8 +59,6 @@ public class OdbcTable extends Table { static { Map tempMap = new HashMap<>(); tempMap.put("oracle", TOdbcTableType.ORACLE); - // we will support mysql driver in the future after we solve the core problem of - // driver and static library tempMap.put("mysql", TOdbcTableType.MYSQL); TABLE_TYPE_MAP = Collections.unmodifiableMap(tempMap); } @@ -244,6 +242,36 @@ public String getOdbcTableTypeName() { return getPropertyFromResource(ODBC_TYPE); } + public String getConnectString() { + String connectString = ""; + // different database have different connection string + switch (getOdbcTableType()) { + case ORACLE: + connectString = String.format("Driver=%s;Dbq=//%s:%s/%s;DataBase=%s;Uid=%s;Pwd=%s;charset=%s", + getOdbcDriver(), + getHost(), + getPort(), + getOdbcDatabaseName(), + getOdbcDatabaseName(), + getUserName(), + getPasswd(), + "utf8"); + break; + case MYSQL: + connectString = String.format("Driver=%s;Server=%s;Port=%s;DataBase=%s;Uid=%s;Pwd=%s;charset=%s", + getOdbcDriver(), + getHost(), + getPort(), + getOdbcDatabaseName(), + getUserName(), + getPasswd(), + "utf8"); + break; + default: + } + return connectString; + } + public TOdbcTableType getOdbcTableType() { return TABLE_TYPE_MAP.get(getOdbcTableTypeName()); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/planner/MysqlScanNode.java b/fe/fe-core/src/main/java/org/apache/doris/planner/MysqlScanNode.java index c74bd4c82b3780..e54875e035a15a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/planner/MysqlScanNode.java +++ b/fe/fe-core/src/main/java/org/apache/doris/planner/MysqlScanNode.java @@ -92,6 +92,11 @@ private String getMysqlQueryStr() { sql.append(Joiner.on(") AND (").join(filters)); sql.append(")"); } + + if (limit != -1) { + sql.append(" LIMIT " + limit); + } + return sql.toString(); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/planner/OdbcScanNode.java b/fe/fe-core/src/main/java/org/apache/doris/planner/OdbcScanNode.java index 66928d23dc864d..6a37415b7e67f7 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/planner/OdbcScanNode.java +++ b/fe/fe-core/src/main/java/org/apache/doris/planner/OdbcScanNode.java @@ -78,7 +78,7 @@ private static boolean needPushDown(TOdbcTableType tableType, Expr expr) { private final List columns = new ArrayList(); private final List filters = new ArrayList(); private String tblName; - private String driver; + private String connectString; private TOdbcTableType odbcType; /** @@ -86,7 +86,7 @@ private static boolean needPushDown(TOdbcTableType tableType, Expr expr) { */ public OdbcScanNode(PlanNodeId id, TupleDescriptor desc, OdbcTable tbl) { super(id, desc, "SCAN ODBC"); - driver = tbl.getOdbcDriver(); + connectString = tbl.getConnectString(); odbcType = tbl.getOdbcTableType(); tblName = databaseProperName(odbcType, tbl.getOdbcTableName()); } @@ -109,12 +109,24 @@ public void finalize(Analyzer analyzer) throws UserException { protected String getNodeExplainString(String prefix, TExplainLevel detailLevel) { StringBuilder output = new StringBuilder(); output.append(prefix).append("TABLE: ").append(tblName).append("\n"); - output.append(prefix).append("Query: ").append(getOdbcQueryStr()).append("\n"); + output.append(prefix).append("TABLE TYPE: ").append(odbcType.toString()).append("\n"); + output.append(prefix).append("QUERY: ").append(getOdbcQueryStr()).append("\n"); return output.toString(); } private String getOdbcQueryStr() { StringBuilder sql = new StringBuilder("SELECT "); + + // Oracle use the where clause to do top n + if (limit != -1 && odbcType == TOdbcTableType.ORACLE) { + filters.add("ROWNUM <= " + limit); + } + + // MSSQL use select top to do top n + if (limit != -1 && odbcType == TOdbcTableType.SQLSERVER) { + sql.append("TOP " + limit + " "); + } + sql.append(Joiner.on(", ").join(columns)); sql.append(" FROM ").append(tblName); @@ -123,6 +135,12 @@ private String getOdbcQueryStr() { sql.append(Joiner.on(") AND (").join(filters)); sql.append(")"); } + + // Other DataBase use limit do top n + if (limit != -1 && (odbcType == TOdbcTableType.MYSQL || odbcType == TOdbcTableType.POSTGRESQL || odbcType == TOdbcTableType.MONGODB) ) { + sql.append(" LIMIT " + limit); + } + return sql.toString(); } @@ -172,10 +190,8 @@ protected void toThrift(TPlanNode msg) { TOdbcScanNode odbcScanNode = new TOdbcScanNode(); odbcScanNode.setTupleId(desc.getId().asInt()); odbcScanNode.setTableName(tblName); - odbcScanNode.setDriver(driver); - odbcScanNode.setType(odbcType); - odbcScanNode.setColumns(columns); - odbcScanNode.setFilters(filters); + odbcScanNode.setConnectString(connectString); + odbcScanNode.setQueryString(getOdbcQueryStr()); msg.odbc_scan_node = odbcScanNode; } diff --git a/fe/fe-core/src/test/java/org/apache/doris/planner/QueryPlanTest.java b/fe/fe-core/src/test/java/org/apache/doris/planner/QueryPlanTest.java index 970a82e8e4838f..6441e9e28c811d 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/planner/QueryPlanTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/planner/QueryPlanTest.java @@ -1128,6 +1128,26 @@ public void testPushDownOfOdbcTable() throws Exception { Assert.assertTrue(!explainString.contains("abs(k1) > 10")); } + @Test + public void testLimitOfExternalTable() throws Exception { + connectContext.setDatabase("default_cluster:test"); + + // ODBC table (MySQL) + String queryStr = "explain select * from odbc_mysql where k1 > 10 and abs(k1) > 10 limit 10"; + String explainString = UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, queryStr); + Assert.assertTrue(explainString.contains("LIMIT 10")); + + // ODBC table (Oracle) + queryStr = "explain select * from odbc_oracle where k1 > 10 and abs(k1) > 10 limit 10"; + explainString = UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, queryStr); + Assert.assertTrue(explainString.contains("ROWNUM <= 10")); + + // MySQL table + queryStr = "explain select * from mysql_table where k1 > 10 and abs(k1) > 10 limit 10"; + explainString = UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, queryStr); + Assert.assertTrue(explainString.contains("LIMIT 10")); + } + @Test public void testPreferBroadcastJoin() throws Exception { connectContext.setDatabase("default_cluster:test"); diff --git a/gensrc/thrift/PlanNodes.thrift b/gensrc/thrift/PlanNodes.thrift index e680bf96741f51..61036447b3ac0b 100644 --- a/gensrc/thrift/PlanNodes.thrift +++ b/gensrc/thrift/PlanNodes.thrift @@ -202,10 +202,16 @@ struct TMySQLScanNode { struct TOdbcScanNode { 1: optional Types.TTupleId tuple_id 2: optional string table_name + + //Deprecated 3: optional string driver 4: optional Types.TOdbcTableType type 5: optional list columns 6: optional list filters + + //Use now + 7: optional string connect_string + 8: optional string query_string }