Skip to content
Merged
13 changes: 13 additions & 0 deletions ddl/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ import (
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/meta"
"github.com/pingcap/tidb/meta/autoid"
"github.com/pingcap/tidb/planner/core"
"github.com/pingcap/tidb/session"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/sessionctx/variable"
Expand Down Expand Up @@ -5360,6 +5361,18 @@ func (s *testSerialDBSuite) TestAlterShardRowIDBits(c *C) {
c.Assert(err.Error(), Equals, "[autoid:1467]Failed to read auto-increment value from storage engine")
}

func (s *testSerialDBSuite) TestShardRowIDBitsOnTemporaryTable(c *C) {
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("use test")
tk.MustExec("drop table if exists shard_row_id_temporary")
_, err := tk.Exec("create global temporary table shard_row_id_temporary (a int) shard_row_id_bits = 5 on commit delete rows;")
c.Assert(err.Error(), Equals, core.ErrOptOnTemporaryTable.GenWithStackByArgs("shard_row_id_bits").Error())
tk.MustExec("create global temporary table shard_row_id_temporary (a int) on commit delete rows;")
defer tk.MustExec("drop table if exists shard_row_id_temporary")
_, err = tk.Exec("alter table shard_row_id_temporary shard_row_id_bits = 4;")
c.Assert(err.Error(), Equals, ddl.ErrOptOnTemporaryTable.GenWithStackByArgs("shard_row_id_bits").Error())
}

// port from mysql
// https://github.com/mysql/mysql-server/blob/124c7ab1d6f914637521fd4463a993aa73403513/mysql-test/t/lock.test
func (s *testDBSuite2) TestLock(c *C) {
Expand Down
3 changes: 3 additions & 0 deletions ddl/ddl_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -2647,6 +2647,9 @@ func (d *ddl) ShardRowID(ctx sessionctx.Context, tableIdent ast.Ident, uVal uint
if err != nil {
return errors.Trace(err)
}
if t.Meta().TempTableType != model.TempTableNone {
return ErrOptOnTemporaryTable.GenWithStackByArgs("shard_row_id_bits")
}
if uVal == t.Meta().ShardRowIDBits {
// Nothing need to do.
return nil
Expand Down
11 changes: 11 additions & 0 deletions ddl/serial_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -942,6 +942,17 @@ func (s *testSerialSuite) TestTableLocksEnable(c *C) {
checkTableLock(c, tk.Se, "test", "t1", model.TableLockNone)
}

func (s *testSerialDBSuite) TestAutoRandomOnTemporaryTable(c *C) {
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("use test")
tk.MustExec("drop table if exists auto_random_temporary")
_, err := tk.Exec("create global temporary table auto_random_temporary (a bigint primary key auto_random(3), b varchar(255)) on commit delete rows;")
c.Assert(err.Error(), Equals, core.ErrOptOnTemporaryTable.GenWithStackByArgs("auto_random").Error())
tk.MustExec("set @@tidb_enable_noop_functions = 1")
_, err = tk.Exec("create temporary table t(a bigint key auto_random);")
c.Assert(err.Error(), Equals, core.ErrOptOnTemporaryTable.GenWithStackByArgs("auto_random").Error())
}

func (s *testSerialDBSuite) TestAutoRandom(c *C) {
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("create database if not exists auto_random_db")
Expand Down
12 changes: 8 additions & 4 deletions planner/core/planbuilder.go
Original file line number Diff line number Diff line change
Expand Up @@ -3399,29 +3399,33 @@ func (b *PlanBuilder) buildDDL(ctx context.Context, node ast.DDLNode) (Plan, err
authErr = ErrTableaccessDenied.GenWithStackByArgs("ALTER", b.ctx.GetSessionVars().User.AuthUsername,
b.ctx.GetSessionVars().User.AuthHostname, v.Table.Name.L)
}
b.visitInfo = appendVisitInfo(b.visitInfo, mysql.AlterPriv, v.Table.Schema.L,
dbName := v.Table.Schema.L
if dbName == "" {
dbName = b.ctx.GetSessionVars().CurrentDB
}
b.visitInfo = appendVisitInfo(b.visitInfo, mysql.AlterPriv, dbName,
v.Table.Name.L, "", authErr)
for _, spec := range v.Specs {
if spec.Tp == ast.AlterTableRenameTable || spec.Tp == ast.AlterTableExchangePartition {
if b.ctx.GetSessionVars().User != nil {
authErr = ErrTableaccessDenied.GenWithStackByArgs("DROP", b.ctx.GetSessionVars().User.AuthUsername,
b.ctx.GetSessionVars().User.AuthHostname, v.Table.Name.L)
}
b.visitInfo = appendVisitInfo(b.visitInfo, mysql.DropPriv, v.Table.Schema.L,
b.visitInfo = appendVisitInfo(b.visitInfo, mysql.DropPriv, dbName,
v.Table.Name.L, "", authErr)

if b.ctx.GetSessionVars().User != nil {
authErr = ErrTableaccessDenied.GenWithStackByArgs("CREATE", b.ctx.GetSessionVars().User.AuthUsername,
b.ctx.GetSessionVars().User.AuthHostname, spec.NewTable.Name.L)
}
b.visitInfo = appendVisitInfo(b.visitInfo, mysql.CreatePriv, spec.NewTable.Schema.L,
b.visitInfo = appendVisitInfo(b.visitInfo, mysql.CreatePriv, dbName,
spec.NewTable.Name.L, "", authErr)

if b.ctx.GetSessionVars().User != nil {
authErr = ErrTableaccessDenied.GenWithStackByArgs("INSERT", b.ctx.GetSessionVars().User.AuthUsername,
b.ctx.GetSessionVars().User.AuthHostname, spec.NewTable.Name.L)
}
b.visitInfo = appendVisitInfo(b.visitInfo, mysql.InsertPriv, spec.NewTable.Schema.L,
b.visitInfo = appendVisitInfo(b.visitInfo, mysql.InsertPriv, dbName,
spec.NewTable.Name.L, "", authErr)
} else if spec.Tp == ast.AlterTableDropPartition {
if b.ctx.GetSessionVars().User != nil {
Expand Down
16 changes: 14 additions & 2 deletions planner/core/preprocess.go
Original file line number Diff line number Diff line change
Expand Up @@ -640,6 +640,14 @@ func (p *preprocessor) checkCreateTableGrammar(stmt *ast.CreateTableStmt) {
return
}
}
if stmt.TemporaryKeyword != ast.TemporaryNone {
for _, opt := range stmt.Options {
if opt.Tp == ast.TableOptionShardRowID {
p.err = ErrOptOnTemporaryTable.GenWithStackByArgs("shard_row_id_bits")
return
}
}
}
tName := stmt.Table.Name.String()
if isIncorrectName(tName) {
p.err = ddl.ErrWrongTableName.GenWithStackByArgs(tName)
Expand All @@ -656,7 +664,7 @@ func (p *preprocessor) checkCreateTableGrammar(stmt *ast.CreateTableStmt) {
p.err = err
return
}
isPrimary, err := checkColumnOptions(colDef.Options)
isPrimary, err := checkColumnOptions(stmt.TemporaryKeyword != ast.TemporaryNone, colDef.Options)
Comment thread
Howie59 marked this conversation as resolved.
if err != nil {
p.err = err
return
Expand Down Expand Up @@ -813,7 +821,7 @@ func isTableAliasDuplicate(node ast.ResultSetNode, tableAliases map[string]inter
return nil
}

func checkColumnOptions(ops []*ast.ColumnOption) (int, error) {
func checkColumnOptions(isTempTable bool, ops []*ast.ColumnOption) (int, error) {
isPrimary, isGenerated, isStored := 0, 0, false

for _, op := range ops {
Expand All @@ -823,6 +831,10 @@ func checkColumnOptions(ops []*ast.ColumnOption) (int, error) {
case ast.ColumnOptionGenerated:
isGenerated = 1
isStored = op.Stored
case ast.ColumnOptionAutoRandom:
if isTempTable {
return isPrimary, ErrOptOnTemporaryTable.GenWithStackByArgs("auto_random")
}
}
}

Expand Down