diff --git a/auth/auth.go b/auth/auth.go index afe0736e0..24506363b 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -15,6 +15,7 @@ package auth import ( "fmt" + "strings" "github.com/pingcap/parser/format" ) @@ -40,19 +41,24 @@ func (user *UserIdentity) Restore(ctx *format.RestoreCtx) error { return nil } -// String converts UserIdentity to the format user@host. +func EscapeAccountName(s string) string { + // We do not have access to the sql_mode here, + // so assume NO_BACKSLASH_ESCAPES in effect, + // since it is still correct if not set. + return "'" + strings.ReplaceAll(s, "'", "''") + "'" +} + +// String converts UserIdentity to the format 'user'@'host'. func (user *UserIdentity) String() string { - // TODO: Escape username and hostname. if user == nil { return "" } - return fmt.Sprintf("%s@%s", user.Username, user.Hostname) + return fmt.Sprintf("%s@%s", EscapeAccountName(user.Username), EscapeAccountName(user.Hostname)) } -// AuthIdentityString returns matched identity in user@host format +// AuthIdentityString returns matched identity in 'user'@'host' format func (user *UserIdentity) AuthIdentityString() string { - // TODO: Escape username and hostname. - return fmt.Sprintf("%s@%s", user.AuthUsername, user.AuthHostname) + return fmt.Sprintf("%s@%s", EscapeAccountName(user.AuthUsername), EscapeAccountName(user.AuthHostname)) } type RoleIdentity struct { @@ -69,8 +75,7 @@ func (role *RoleIdentity) Restore(ctx *format.RestoreCtx) error { return nil } -// String converts UserIdentity to the format user@host. +// String converts UserIdentity to the format 'user'@'host'. func (role *RoleIdentity) String() string { - // TODO: Escape username and hostname. - return fmt.Sprintf("`%s`@`%s`", role.Username, role.Hostname) + return fmt.Sprintf("%s@%s", EscapeAccountName(role.Username), EscapeAccountName(role.Hostname)) } diff --git a/auth/auth_test.go b/auth/auth_test.go index c196f74f4..7f3ec90aa 100644 --- a/auth/auth_test.go +++ b/auth/auth_test.go @@ -27,3 +27,22 @@ type testAuthSuite struct { func TestT(t *testing.T) { TestingT(t) } + +func (s *testAuthSuite) TestEscapeAccountName(c *C) { + c.Assert(EscapeAccountName(""), Equals, "''") + c.Assert(EscapeAccountName("User"), Equals, "'User'") + c.Assert(EscapeAccountName("User's"), Equals, "'User''s'") + c.Assert(EscapeAccountName("User is me"), Equals, "'User is me'") + c.Assert(EscapeAccountName(`u'v"w\'x\"y@z`+"`a"+`\b\\c`), Equals, "'u''v\"w\\''x\\\"y@z`a\\b\\\\c'") // u'v"\'x\"y@z`a\b\\c -> 'u''v"\''x\"y@z`a\b\\c' + c.Assert(EscapeAccountName("u'v\"w\\'x\\\"y@z`a\\b\\\\c"), Equals, `'u''v"w\''x\"y@z`+"`"+`a\b\\c'`) // u'v"\'x\"y@z`a\b\\c -> 'u''v"\''x\"y@z`a\b\\c' + u := UserIdentity{Username: "U & I @ Party", Hostname: "10.%", CurrentUser: false, AuthUsername: "root's friend", AuthHostname: "server"} + c.Assert(u.String(), Equals, "'U & I @ Party'@'10.%'") + c.Assert(u.AuthIdentityString(), Equals, "'root''s friend'@'server'") + u = UserIdentity{Username: "", Hostname: "", CurrentUser: false, AuthUsername: "ceo", AuthHostname: "%"} + c.Assert(u.String(), Equals, "''@''") + c.Assert(u.AuthIdentityString(), Equals, "'ceo'@'%'") + var uNil *UserIdentity = nil + c.Assert(uNil.String(), Equals, "") + r := RoleIdentity{Username: "Admin", Hostname: "192.168.%"} + c.Assert(r.String(), Equals, "'Admin'@'192.168.%'") +}