Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/4264.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix CAS login when username is not valid in an MXID
13 changes: 11 additions & 2 deletions synapse/handlers/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,10 +563,10 @@ def check_user_exists(self, user_id):
insensitively, but return None if there are multiple inexact matches.

Args:
(str) user_id: complete @user:id
(unicode|bytes) user_id: complete @user:id

Returns:
defer.Deferred: (str) canonical_user_id, or None if zero or
defer.Deferred: (unicode) canonical_user_id, or None if zero or
multiple matches
"""
res = yield self._find_user_id_and_pwd_hash(user_id)
Expand Down Expand Up @@ -954,6 +954,15 @@ def generate_access_token(self, user_id, extra_caveats=None):
return macaroon.serialize()

def generate_short_term_login_token(self, user_id, duration_in_ms=(2 * 60 * 1000)):
"""

Args:
user_id (unicode):
duration_in_ms (int):

Returns:
unicode
"""
macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = login")
now = self.hs.get_clock().time_msec()
Expand Down
105 changes: 76 additions & 29 deletions synapse/rest/client/v1/login.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,12 @@

from synapse.api.errors import Codes, LoginError, SynapseError
from synapse.http.server import finish_request
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.types import UserID
from synapse.http.servlet import (
RestServlet,
parse_json_object_from_request,
parse_string,
)
from synapse.types import UserID, map_username_to_mxid_localpart
from synapse.util.msisdn import phone_number_to_msisdn

from .base import ClientV1RestServlet, client_path_patterns
Expand Down Expand Up @@ -421,17 +425,15 @@ def __init__(self, hs):
self.cas_server_url = hs.config.cas_server_url
self.cas_service_url = hs.config.cas_service_url
self.cas_required_attributes = hs.config.cas_required_attributes
self.auth_handler = hs.get_auth_handler()
self.handlers = hs.get_handlers()
self.macaroon_gen = hs.get_macaroon_generator()
self._sso_auth_handler = SSOAuthHandler(hs)

@defer.inlineCallbacks
def on_GET(self, request):
client_redirect_url = request.args[b"redirectUrl"][0]
client_redirect_url = parse_string(request, "redirectUrl", required=True)
http_client = self.hs.get_simple_http_client()
uri = self.cas_server_url + "/proxyValidate"
args = {
"ticket": request.args[b"ticket"][0].decode('ascii'),
"ticket": parse_string(request, "ticket", required=True),
"service": self.cas_service_url
}
try:
Expand All @@ -443,7 +445,6 @@ def on_GET(self, request):
result = yield self.handle_cas_response(request, body, client_redirect_url)
defer.returnValue(result)

@defer.inlineCallbacks
def handle_cas_response(self, request, cas_response_body, client_redirect_url):
user, attributes = self.parse_cas_response(cas_response_body)

Expand All @@ -459,28 +460,9 @@ def handle_cas_response(self, request, cas_response_body, client_redirect_url):
if required_value != actual_value:
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)

user_id = UserID(user, self.hs.hostname).to_string()
auth_handler = self.auth_handler
registered_user_id = yield auth_handler.check_user_exists(user_id)
if not registered_user_id:
registered_user_id, _ = (
yield self.handlers.registration_handler.register(localpart=user)
)

login_token = self.macaroon_gen.generate_short_term_login_token(
registered_user_id
return self._sso_auth_handler.on_successful_auth(
user, request, client_redirect_url,
)
redirect_url = self.add_login_token_to_redirect_url(client_redirect_url,
login_token)
request.redirect(redirect_url)
finish_request(request)

def add_login_token_to_redirect_url(self, url, token):
url_parts = list(urllib.parse.urlparse(url))
query = dict(urllib.parse.parse_qsl(url_parts[4]))
query.update({"loginToken": token})
url_parts[4] = urllib.parse.urlencode(query).encode('ascii')
return urllib.parse.urlunparse(url_parts)

def parse_cas_response(self, cas_response_body):
user = None
Expand Down Expand Up @@ -515,6 +497,71 @@ def parse_cas_response(self, cas_response_body):
return user, attributes


class SSOAuthHandler(object):
"""
Utility class for Resources and Servlets which handle the response from a SSO
service

Args:
hs (synapse.server.HomeServer)
"""
def __init__(self, hs):
self._hostname = hs.hostname
self._auth_handler = hs.get_auth_handler()
self._registration_handler = hs.get_handlers().registration_handler
self._macaroon_gen = hs.get_macaroon_generator()

@defer.inlineCallbacks
def on_successful_auth(
self, username, request, client_redirect_url,
):
"""Called once the user has successfully authenticated with the SSO.

Registers the user if necessary, and then returns a redirect (with
a login token) to the client.

Args:
username (unicode|bytes): the remote user id. We'll map this onto
something sane for a MXID localpath.

request (SynapseRequest): the incoming request from the browser. We'll
respond to it with a redirect.

client_redirect_url (unicode): the redirect_url the client gave us when
it first started the process.

Returns:
Deferred[none]: Completes once we have handled the request.
"""
localpart = map_username_to_mxid_localpart(username)
user_id = UserID(localpart, self._hostname).to_string()
registered_user_id = yield self._auth_handler.check_user_exists(user_id)
if not registered_user_id:
registered_user_id, _ = (
yield self._registration_handler.register(
localpart=localpart,
generate_token=False,
)
)

login_token = self._macaroon_gen.generate_short_term_login_token(
registered_user_id
)
redirect_url = self._add_login_token_to_redirect_url(
client_redirect_url, login_token
)
request.redirect(redirect_url)
finish_request(request)

@staticmethod
def _add_login_token_to_redirect_url(url, token):
url_parts = list(urllib.parse.urlparse(url))
query = dict(urllib.parse.parse_qsl(url_parts[4]))
query.update({"loginToken": token})
url_parts[4] = urllib.parse.urlencode(query)
return urllib.parse.urlunparse(url_parts)


def register_servlets(hs, http_server):
LoginRestServlet(hs).register(http_server)
if hs.config.saml2_enabled:
Expand Down
66 changes: 66 additions & 0 deletions synapse/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
import string
from collections import namedtuple

Expand Down Expand Up @@ -228,6 +229,71 @@ def contains_invalid_mxid_characters(localpart):
return any(c not in mxid_localpart_allowed_characters for c in localpart)


UPPER_CASE_PATTERN = re.compile(b"[A-Z_]")

# the following is a pattern which matches '=', and bytes which are not allowed in a mxid
# localpart.
#
# It works by:
# * building a string containing the allowed characters (excluding '=')
# * escaping every special character with a backslash (to stop '-' being interpreted as a
# range operator)
# * wrapping it in a '[^...]' regex
# * converting the whole lot to a 'bytes' sequence, so that we can use it to match
# bytes rather than strings
#
NON_MXID_CHARACTER_PATTERN = re.compile(
("[^%s]" % (
re.escape("".join(mxid_localpart_allowed_characters - {"="}),),
)).encode("ascii"),
)


def map_username_to_mxid_localpart(username, case_sensitive=False):
"""Map a username onto a string suitable for a MXID

This follows the algorithm laid out at
https://matrix.org/docs/spec/appendices.html#mapping-from-other-character-sets.

Args:
username (unicode|bytes): username to be mapped
case_sensitive (bool): true if TEST and test should be mapped
onto different mxids

Returns:
unicode: string suitable for a mxid localpart
"""
if not isinstance(username, bytes):
username = username.encode('utf-8')

# first we sort out upper-case characters
if case_sensitive:
def f1(m):
return b"_" + m.group().lower()

username = UPPER_CASE_PATTERN.sub(f1, username)
else:
username = username.lower()

# then we sort out non-ascii characters
def f2(m):
g = m.group()[0]
if isinstance(g, str):
# on python 2, we need to do a ord(). On python 3, the
# byte itself will do.
g = ord(g)
return b"=%02x" % (g,)

username = NON_MXID_CHARACTER_PATTERN.sub(f2, username)

# we also do the =-escaping to mxids starting with an underscore.
username = re.sub(b'^_', b'=5f', username)

# we should now only have ascii bytes left, so can decode back to a
# unicode.
return username.decode('ascii')


class StreamToken(
namedtuple("Token", (
"room_key",
Expand Down
31 changes: 30 additions & 1 deletion tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# limitations under the License.

from synapse.api.errors import SynapseError
from synapse.types import GroupID, RoomAlias, UserID
from synapse.types import GroupID, RoomAlias, UserID, map_username_to_mxid_localpart

from tests import unittest
from tests.utils import TestHomeServer
Expand Down Expand Up @@ -79,3 +79,32 @@ def test_validate(self):
except SynapseError as exc:
self.assertEqual(400, exc.code)
self.assertEqual("M_UNKNOWN", exc.errcode)


class MapUsernameTestCase(unittest.TestCase):
def testPassThrough(self):
self.assertEqual(map_username_to_mxid_localpart("test1234"), "test1234")

def testUpperCase(self):
self.assertEqual(map_username_to_mxid_localpart("tEST_1234"), "test_1234")
self.assertEqual(
map_username_to_mxid_localpart("tEST_1234", case_sensitive=True),
"t_e_s_t__1234",
)

def testSymbols(self):
self.assertEqual(
map_username_to_mxid_localpart("test=$?_1234"),
"test=3d=24=3f_1234",
)

def testLeadingUnderscore(self):
self.assertEqual(map_username_to_mxid_localpart("_test_1234"), "=5ftest_1234")

def testNonAscii(self):
# this should work with either a unicode or a bytes
self.assertEqual(map_username_to_mxid_localpart(u'têst'), "t=c3=aast")
self.assertEqual(
map_username_to_mxid_localpart(u'têst'.encode('utf-8')),
"t=c3=aast",
)