mirror of
				https://github.com/matrix-org/synapse.git
				synced 2025-11-04 02:01:03 +01:00 
			
		
		
		
	Merge pull request #2727 from matrix-org/rav/refactor_ui_auth_return
Refactor UI auth implementation
This commit is contained in:
		
						commit
						aa6ecf0984
					
				@ -140,6 +140,22 @@ class RegistrationError(SynapseError):
 | 
			
		||||
    pass
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class InteractiveAuthIncompleteError(Exception):
 | 
			
		||||
    """An error raised when UI auth is not yet complete
 | 
			
		||||
 | 
			
		||||
    (This indicates we should return a 401 with 'result' as the body)
 | 
			
		||||
 | 
			
		||||
    Attributes:
 | 
			
		||||
        result (dict): the server response to the request, which should be
 | 
			
		||||
            passed back to the client
 | 
			
		||||
    """
 | 
			
		||||
    def __init__(self, result):
 | 
			
		||||
        super(InteractiveAuthIncompleteError, self).__init__(
 | 
			
		||||
            "Interactive auth not yet complete",
 | 
			
		||||
        )
 | 
			
		||||
        self.result = result
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class UnrecognizedRequestError(SynapseError):
 | 
			
		||||
    """An error indicating we don't understand the request you're trying to make"""
 | 
			
		||||
    def __init__(self, *args, **kwargs):
 | 
			
		||||
 | 
			
		||||
@ -17,7 +17,10 @@ from twisted.internet import defer
 | 
			
		||||
 | 
			
		||||
from ._base import BaseHandler
 | 
			
		||||
from synapse.api.constants import LoginType
 | 
			
		||||
from synapse.api.errors import AuthError, LoginError, Codes, StoreError, SynapseError
 | 
			
		||||
from synapse.api.errors import (
 | 
			
		||||
    AuthError, Codes, InteractiveAuthIncompleteError, LoginError, StoreError,
 | 
			
		||||
    SynapseError,
 | 
			
		||||
)
 | 
			
		||||
from synapse.module_api import ModuleApi
 | 
			
		||||
from synapse.types import UserID
 | 
			
		||||
from synapse.util.async import run_on_reactor
 | 
			
		||||
@ -95,26 +98,36 @@ class AuthHandler(BaseHandler):
 | 
			
		||||
        session with a map, which maps each auth-type (str) to the relevant
 | 
			
		||||
        identity authenticated by that auth-type (mostly str, but for captcha, bool).
 | 
			
		||||
 | 
			
		||||
        If no auth flows have been completed successfully, raises an
 | 
			
		||||
        InteractiveAuthIncompleteError. To handle this, you can use
 | 
			
		||||
        synapse.rest.client.v2_alpha._base.interactive_auth_handler as a
 | 
			
		||||
        decorator.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            flows (list): A list of login flows. Each flow is an ordered list of
 | 
			
		||||
                          strings representing auth-types. At least one full
 | 
			
		||||
                          flow must be completed in order for auth to be successful.
 | 
			
		||||
 | 
			
		||||
            clientdict: The dictionary from the client root level, not the
 | 
			
		||||
                        'auth' key: this method prompts for auth if none is sent.
 | 
			
		||||
 | 
			
		||||
            clientip (str): The IP address of the client.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            A tuple of (authed, dict, dict, session_id) where authed is true if
 | 
			
		||||
            the client has successfully completed an auth flow. If it is true
 | 
			
		||||
            the first dict contains the authenticated credentials of each stage.
 | 
			
		||||
            defer.Deferred[dict, dict, str]: a deferred tuple of
 | 
			
		||||
                (creds, params, session_id).
 | 
			
		||||
 | 
			
		||||
            If authed is false, the first dictionary is the server response to
 | 
			
		||||
            the login request and should be passed back to the client.
 | 
			
		||||
                'creds' contains the authenticated credentials of each stage.
 | 
			
		||||
 | 
			
		||||
            In either case, the second dict contains the parameters for this
 | 
			
		||||
            request (which may have been given only in a previous call).
 | 
			
		||||
                'params' contains the parameters for this request (which may
 | 
			
		||||
                have been given only in a previous call).
 | 
			
		||||
 | 
			
		||||
            session_id is the ID of this session, either passed in by the client
 | 
			
		||||
            or assigned by the call to check_auth
 | 
			
		||||
                'session_id' is the ID of this session, either passed in by the
 | 
			
		||||
                client or assigned by this call
 | 
			
		||||
 | 
			
		||||
        Raises:
 | 
			
		||||
            InteractiveAuthIncompleteError if the client has not yet completed
 | 
			
		||||
                all the stages in any of the permitted flows.
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
        authdict = None
 | 
			
		||||
@ -142,11 +155,8 @@ class AuthHandler(BaseHandler):
 | 
			
		||||
            clientdict = session['clientdict']
 | 
			
		||||
 | 
			
		||||
        if not authdict:
 | 
			
		||||
            defer.returnValue(
 | 
			
		||||
                (
 | 
			
		||||
                    False, self._auth_dict_for_flows(flows, session),
 | 
			
		||||
                    clientdict, session['id']
 | 
			
		||||
                )
 | 
			
		||||
            raise InteractiveAuthIncompleteError(
 | 
			
		||||
                self._auth_dict_for_flows(flows, session),
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        if 'creds' not in session:
 | 
			
		||||
@ -190,12 +200,14 @@ class AuthHandler(BaseHandler):
 | 
			
		||||
                    "Auth completed with creds: %r. Client dict has keys: %r",
 | 
			
		||||
                    creds, clientdict.keys()
 | 
			
		||||
                )
 | 
			
		||||
                defer.returnValue((True, creds, clientdict, session['id']))
 | 
			
		||||
                defer.returnValue((creds, clientdict, session['id']))
 | 
			
		||||
 | 
			
		||||
        ret = self._auth_dict_for_flows(flows, session)
 | 
			
		||||
        ret['completed'] = creds.keys()
 | 
			
		||||
        ret.update(errordict)
 | 
			
		||||
        defer.returnValue((False, ret, clientdict, session['id']))
 | 
			
		||||
        raise InteractiveAuthIncompleteError(
 | 
			
		||||
            ret,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def add_oob_auth(self, stagetype, authdict, clientip):
 | 
			
		||||
 | 
			
		||||
@ -15,12 +15,13 @@
 | 
			
		||||
 | 
			
		||||
"""This module contains base REST classes for constructing client v1 servlets.
 | 
			
		||||
"""
 | 
			
		||||
 | 
			
		||||
from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX
 | 
			
		||||
import logging
 | 
			
		||||
import re
 | 
			
		||||
 | 
			
		||||
import logging
 | 
			
		||||
from twisted.internet import defer
 | 
			
		||||
 | 
			
		||||
from synapse.api.errors import InteractiveAuthIncompleteError
 | 
			
		||||
from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX
 | 
			
		||||
 | 
			
		||||
logger = logging.getLogger(__name__)
 | 
			
		||||
 | 
			
		||||
@ -57,3 +58,37 @@ def set_timeline_upper_limit(filter_json, filter_timeline_limit):
 | 
			
		||||
        filter_json['room']['timeline']["limit"] = min(
 | 
			
		||||
            filter_json['room']['timeline']['limit'],
 | 
			
		||||
            filter_timeline_limit)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def interactive_auth_handler(orig):
 | 
			
		||||
    """Wraps an on_POST method to handle InteractiveAuthIncompleteErrors
 | 
			
		||||
 | 
			
		||||
    Takes a on_POST method which returns a deferred (errcode, body) response
 | 
			
		||||
    and adds exception handling to turn a InteractiveAuthIncompleteError into
 | 
			
		||||
    a 401 response.
 | 
			
		||||
 | 
			
		||||
    Normal usage is:
 | 
			
		||||
 | 
			
		||||
    @interactive_auth_handler
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def on_POST(self, request):
 | 
			
		||||
        # ...
 | 
			
		||||
        yield self.auth_handler.check_auth
 | 
			
		||||
            """
 | 
			
		||||
    def wrapped(*args, **kwargs):
 | 
			
		||||
        res = defer.maybeDeferred(orig, *args, **kwargs)
 | 
			
		||||
        res.addErrback(_catch_incomplete_interactive_auth)
 | 
			
		||||
        return res
 | 
			
		||||
    return wrapped
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _catch_incomplete_interactive_auth(f):
 | 
			
		||||
    """helper for interactive_auth_handler
 | 
			
		||||
 | 
			
		||||
    Catches InteractiveAuthIncompleteErrors and turns them into 401 responses
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        f (failure.Failure):
 | 
			
		||||
    """
 | 
			
		||||
    f.trap(InteractiveAuthIncompleteError)
 | 
			
		||||
    return 401, f.value.result
 | 
			
		||||
 | 
			
		||||
@ -26,7 +26,7 @@ from synapse.http.servlet import (
 | 
			
		||||
)
 | 
			
		||||
from synapse.util.async import run_on_reactor
 | 
			
		||||
from synapse.util.msisdn import phone_number_to_msisdn
 | 
			
		||||
from ._base import client_v2_patterns
 | 
			
		||||
from ._base import client_v2_patterns, interactive_auth_handler
 | 
			
		||||
 | 
			
		||||
logger = logging.getLogger(__name__)
 | 
			
		||||
 | 
			
		||||
@ -100,21 +100,19 @@ class PasswordRestServlet(RestServlet):
 | 
			
		||||
        self.datastore = self.hs.get_datastore()
 | 
			
		||||
        self._set_password_handler = hs.get_set_password_handler()
 | 
			
		||||
 | 
			
		||||
    @interactive_auth_handler
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def on_POST(self, request):
 | 
			
		||||
        yield run_on_reactor()
 | 
			
		||||
 | 
			
		||||
        body = parse_json_object_from_request(request)
 | 
			
		||||
 | 
			
		||||
        authed, result, params, _ = yield self.auth_handler.check_auth([
 | 
			
		||||
        result, params, _ = yield self.auth_handler.check_auth([
 | 
			
		||||
            [LoginType.PASSWORD],
 | 
			
		||||
            [LoginType.EMAIL_IDENTITY],
 | 
			
		||||
            [LoginType.MSISDN],
 | 
			
		||||
        ], body, self.hs.get_ip_from_request(request))
 | 
			
		||||
 | 
			
		||||
        if not authed:
 | 
			
		||||
            defer.returnValue((401, result))
 | 
			
		||||
 | 
			
		||||
        user_id = None
 | 
			
		||||
        requester = None
 | 
			
		||||
 | 
			
		||||
@ -168,6 +166,7 @@ class DeactivateAccountRestServlet(RestServlet):
 | 
			
		||||
        self.auth_handler = hs.get_auth_handler()
 | 
			
		||||
        self._deactivate_account_handler = hs.get_deactivate_account_handler()
 | 
			
		||||
 | 
			
		||||
    @interactive_auth_handler
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def on_POST(self, request):
 | 
			
		||||
        body = parse_json_object_from_request(request)
 | 
			
		||||
@ -186,13 +185,10 @@ class DeactivateAccountRestServlet(RestServlet):
 | 
			
		||||
            )
 | 
			
		||||
            defer.returnValue((200, {}))
 | 
			
		||||
 | 
			
		||||
        authed, result, params, _ = yield self.auth_handler.check_auth([
 | 
			
		||||
        result, params, _ = yield self.auth_handler.check_auth([
 | 
			
		||||
            [LoginType.PASSWORD],
 | 
			
		||||
        ], body, self.hs.get_ip_from_request(request))
 | 
			
		||||
 | 
			
		||||
        if not authed:
 | 
			
		||||
            defer.returnValue((401, result))
 | 
			
		||||
 | 
			
		||||
        if LoginType.PASSWORD in result:
 | 
			
		||||
            user_id = result[LoginType.PASSWORD]
 | 
			
		||||
            # if using password, they should also be logged in
 | 
			
		||||
 | 
			
		||||
@ -19,7 +19,7 @@ from twisted.internet import defer
 | 
			
		||||
 | 
			
		||||
from synapse.api import constants, errors
 | 
			
		||||
from synapse.http import servlet
 | 
			
		||||
from ._base import client_v2_patterns
 | 
			
		||||
from ._base import client_v2_patterns, interactive_auth_handler
 | 
			
		||||
 | 
			
		||||
logger = logging.getLogger(__name__)
 | 
			
		||||
 | 
			
		||||
@ -60,6 +60,7 @@ class DeleteDevicesRestServlet(servlet.RestServlet):
 | 
			
		||||
        self.device_handler = hs.get_device_handler()
 | 
			
		||||
        self.auth_handler = hs.get_auth_handler()
 | 
			
		||||
 | 
			
		||||
    @interactive_auth_handler
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def on_POST(self, request):
 | 
			
		||||
        try:
 | 
			
		||||
@ -77,13 +78,10 @@ class DeleteDevicesRestServlet(servlet.RestServlet):
 | 
			
		||||
                400, "No devices supplied", errcode=errors.Codes.MISSING_PARAM
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        authed, result, params, _ = yield self.auth_handler.check_auth([
 | 
			
		||||
        result, params, _ = yield self.auth_handler.check_auth([
 | 
			
		||||
            [constants.LoginType.PASSWORD],
 | 
			
		||||
        ], body, self.hs.get_ip_from_request(request))
 | 
			
		||||
 | 
			
		||||
        if not authed:
 | 
			
		||||
            defer.returnValue((401, result))
 | 
			
		||||
 | 
			
		||||
        requester = yield self.auth.get_user_by_req(request)
 | 
			
		||||
        yield self.device_handler.delete_devices(
 | 
			
		||||
            requester.user.to_string(),
 | 
			
		||||
@ -115,6 +113,7 @@ class DeviceRestServlet(servlet.RestServlet):
 | 
			
		||||
        )
 | 
			
		||||
        defer.returnValue((200, device))
 | 
			
		||||
 | 
			
		||||
    @interactive_auth_handler
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def on_DELETE(self, request, device_id):
 | 
			
		||||
        requester = yield self.auth.get_user_by_req(request)
 | 
			
		||||
@ -130,13 +129,10 @@ class DeviceRestServlet(servlet.RestServlet):
 | 
			
		||||
            else:
 | 
			
		||||
                raise
 | 
			
		||||
 | 
			
		||||
        authed, result, params, _ = yield self.auth_handler.check_auth([
 | 
			
		||||
        result, params, _ = yield self.auth_handler.check_auth([
 | 
			
		||||
            [constants.LoginType.PASSWORD],
 | 
			
		||||
        ], body, self.hs.get_ip_from_request(request))
 | 
			
		||||
 | 
			
		||||
        if not authed:
 | 
			
		||||
            defer.returnValue((401, result))
 | 
			
		||||
 | 
			
		||||
        # check that the UI auth matched the access token
 | 
			
		||||
        user_id = result[constants.LoginType.PASSWORD]
 | 
			
		||||
        if user_id != requester.user.to_string():
 | 
			
		||||
 | 
			
		||||
@ -27,7 +27,7 @@ from synapse.http.servlet import (
 | 
			
		||||
)
 | 
			
		||||
from synapse.util.msisdn import phone_number_to_msisdn
 | 
			
		||||
 | 
			
		||||
from ._base import client_v2_patterns
 | 
			
		||||
from ._base import client_v2_patterns, interactive_auth_handler
 | 
			
		||||
 | 
			
		||||
import logging
 | 
			
		||||
import hmac
 | 
			
		||||
@ -176,6 +176,7 @@ class RegisterRestServlet(RestServlet):
 | 
			
		||||
        self.device_handler = hs.get_device_handler()
 | 
			
		||||
        self.macaroon_gen = hs.get_macaroon_generator()
 | 
			
		||||
 | 
			
		||||
    @interactive_auth_handler
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def on_POST(self, request):
 | 
			
		||||
        yield run_on_reactor()
 | 
			
		||||
@ -325,14 +326,10 @@ class RegisterRestServlet(RestServlet):
 | 
			
		||||
                    [LoginType.MSISDN, LoginType.EMAIL_IDENTITY],
 | 
			
		||||
                ])
 | 
			
		||||
 | 
			
		||||
        authed, auth_result, params, session_id = yield self.auth_handler.check_auth(
 | 
			
		||||
        auth_result, params, session_id = yield self.auth_handler.check_auth(
 | 
			
		||||
            flows, body, self.hs.get_ip_from_request(request)
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        if not authed:
 | 
			
		||||
            defer.returnValue((401, auth_result))
 | 
			
		||||
            return
 | 
			
		||||
 | 
			
		||||
        if registered_user_id is not None:
 | 
			
		||||
            logger.info(
 | 
			
		||||
                "Already registered user ID %r for this session",
 | 
			
		||||
 | 
			
		||||
@ -1,5 +1,7 @@
 | 
			
		||||
from twisted.python import failure
 | 
			
		||||
 | 
			
		||||
from synapse.rest.client.v2_alpha.register import RegisterRestServlet
 | 
			
		||||
from synapse.api.errors import SynapseError
 | 
			
		||||
from synapse.api.errors import SynapseError, InteractiveAuthIncompleteError
 | 
			
		||||
from twisted.internet import defer
 | 
			
		||||
from mock import Mock
 | 
			
		||||
from tests import unittest
 | 
			
		||||
@ -24,7 +26,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
 | 
			
		||||
            side_effect=lambda x: self.appservice)
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        self.auth_result = (False, None, None, None)
 | 
			
		||||
        self.auth_result = failure.Failure(InteractiveAuthIncompleteError(None))
 | 
			
		||||
        self.auth_handler = Mock(
 | 
			
		||||
            check_auth=Mock(side_effect=lambda x, y, z: self.auth_result),
 | 
			
		||||
            get_session_data=Mock(return_value=None)
 | 
			
		||||
@ -86,6 +88,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
 | 
			
		||||
        self.request.args = {
 | 
			
		||||
            "access_token": "i_am_an_app_service"
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        self.request_data = json.dumps({
 | 
			
		||||
            "username": "kermit"
 | 
			
		||||
        })
 | 
			
		||||
@ -120,7 +123,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
 | 
			
		||||
            "device_id": device_id,
 | 
			
		||||
        })
 | 
			
		||||
        self.registration_handler.check_username = Mock(return_value=True)
 | 
			
		||||
        self.auth_result = (True, None, {
 | 
			
		||||
        self.auth_result = (None, {
 | 
			
		||||
            "username": "kermit",
 | 
			
		||||
            "password": "monkey"
 | 
			
		||||
        }, None)
 | 
			
		||||
@ -150,7 +153,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
 | 
			
		||||
            "password": "monkey"
 | 
			
		||||
        })
 | 
			
		||||
        self.registration_handler.check_username = Mock(return_value=True)
 | 
			
		||||
        self.auth_result = (True, None, {
 | 
			
		||||
        self.auth_result = (None, {
 | 
			
		||||
            "username": "kermit",
 | 
			
		||||
            "password": "monkey"
 | 
			
		||||
        }, None)
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user