mirror of
				https://github.com/matrix-org/synapse.git
				synced 2025-10-31 16:21:56 +01:00 
			
		
		
		
	Merge branch 'develop' of github.com:matrix-org/synapse into release-v0.19.0
This commit is contained in:
		
						commit
						f8c407a13b
					
				| @ -65,6 +65,7 @@ class AuthHandler(BaseHandler): | |||||||
| 
 | 
 | ||||||
|         self.hs = hs  # FIXME better possibility to access registrationHandler later? |         self.hs = hs  # FIXME better possibility to access registrationHandler later? | ||||||
|         self.device_handler = hs.get_device_handler() |         self.device_handler = hs.get_device_handler() | ||||||
|  |         self.macaroon_gen = hs.get_macaroon_generator() | ||||||
| 
 | 
 | ||||||
|     @defer.inlineCallbacks |     @defer.inlineCallbacks | ||||||
|     def check_auth(self, flows, clientdict, clientip): |     def check_auth(self, flows, clientdict, clientip): | ||||||
| @ -529,37 +530,11 @@ class AuthHandler(BaseHandler): | |||||||
| 
 | 
 | ||||||
|     @defer.inlineCallbacks |     @defer.inlineCallbacks | ||||||
|     def issue_access_token(self, user_id, device_id=None): |     def issue_access_token(self, user_id, device_id=None): | ||||||
|         access_token = self.generate_access_token(user_id) |         access_token = self.macaroon_gen.generate_access_token(user_id) | ||||||
|         yield self.store.add_access_token_to_user(user_id, access_token, |         yield self.store.add_access_token_to_user(user_id, access_token, | ||||||
|                                                   device_id) |                                                   device_id) | ||||||
|         defer.returnValue(access_token) |         defer.returnValue(access_token) | ||||||
| 
 | 
 | ||||||
|     def generate_access_token(self, user_id, extra_caveats=None): |  | ||||||
|         extra_caveats = extra_caveats or [] |  | ||||||
|         macaroon = self._generate_base_macaroon(user_id) |  | ||||||
|         macaroon.add_first_party_caveat("type = access") |  | ||||||
|         # Include a nonce, to make sure that each login gets a different |  | ||||||
|         # access token. |  | ||||||
|         macaroon.add_first_party_caveat("nonce = %s" % ( |  | ||||||
|             stringutils.random_string_with_symbols(16), |  | ||||||
|         )) |  | ||||||
|         for caveat in extra_caveats: |  | ||||||
|             macaroon.add_first_party_caveat(caveat) |  | ||||||
|         return macaroon.serialize() |  | ||||||
| 
 |  | ||||||
|     def generate_short_term_login_token(self, user_id, duration_in_ms=(2 * 60 * 1000)): |  | ||||||
|         macaroon = self._generate_base_macaroon(user_id) |  | ||||||
|         macaroon.add_first_party_caveat("type = login") |  | ||||||
|         now = self.hs.get_clock().time_msec() |  | ||||||
|         expiry = now + duration_in_ms |  | ||||||
|         macaroon.add_first_party_caveat("time < %d" % (expiry,)) |  | ||||||
|         return macaroon.serialize() |  | ||||||
| 
 |  | ||||||
|     def generate_delete_pusher_token(self, user_id): |  | ||||||
|         macaroon = self._generate_base_macaroon(user_id) |  | ||||||
|         macaroon.add_first_party_caveat("type = delete_pusher") |  | ||||||
|         return macaroon.serialize() |  | ||||||
| 
 |  | ||||||
|     def validate_short_term_login_token_and_get_user_id(self, login_token): |     def validate_short_term_login_token_and_get_user_id(self, login_token): | ||||||
|         auth_api = self.hs.get_auth() |         auth_api = self.hs.get_auth() | ||||||
|         try: |         try: | ||||||
| @ -570,15 +545,6 @@ class AuthHandler(BaseHandler): | |||||||
|         except Exception: |         except Exception: | ||||||
|             raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN) |             raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN) | ||||||
| 
 | 
 | ||||||
|     def _generate_base_macaroon(self, user_id): |  | ||||||
|         macaroon = pymacaroons.Macaroon( |  | ||||||
|             location=self.hs.config.server_name, |  | ||||||
|             identifier="key", |  | ||||||
|             key=self.hs.config.macaroon_secret_key) |  | ||||||
|         macaroon.add_first_party_caveat("gen = 1") |  | ||||||
|         macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) |  | ||||||
|         return macaroon |  | ||||||
| 
 |  | ||||||
|     @defer.inlineCallbacks |     @defer.inlineCallbacks | ||||||
|     def set_password(self, user_id, newpassword, requester=None): |     def set_password(self, user_id, newpassword, requester=None): | ||||||
|         password_hash = self.hash(newpassword) |         password_hash = self.hash(newpassword) | ||||||
| @ -673,6 +639,48 @@ class AuthHandler(BaseHandler): | |||||||
|             return False |             return False | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | class MacaroonGeneartor(object): | ||||||
|  |     def __init__(self, hs): | ||||||
|  |         self.clock = hs.get_clock() | ||||||
|  |         self.server_name = hs.config.server_name | ||||||
|  |         self.macaroon_secret_key = hs.config.macaroon_secret_key | ||||||
|  | 
 | ||||||
|  |     def generate_access_token(self, user_id, extra_caveats=None): | ||||||
|  |         extra_caveats = extra_caveats or [] | ||||||
|  |         macaroon = self._generate_base_macaroon(user_id) | ||||||
|  |         macaroon.add_first_party_caveat("type = access") | ||||||
|  |         # Include a nonce, to make sure that each login gets a different | ||||||
|  |         # access token. | ||||||
|  |         macaroon.add_first_party_caveat("nonce = %s" % ( | ||||||
|  |             stringutils.random_string_with_symbols(16), | ||||||
|  |         )) | ||||||
|  |         for caveat in extra_caveats: | ||||||
|  |             macaroon.add_first_party_caveat(caveat) | ||||||
|  |         return macaroon.serialize() | ||||||
|  | 
 | ||||||
|  |     def generate_short_term_login_token(self, user_id, duration_in_ms=(2 * 60 * 1000)): | ||||||
|  |         macaroon = self._generate_base_macaroon(user_id) | ||||||
|  |         macaroon.add_first_party_caveat("type = login") | ||||||
|  |         now = self.clock.time_msec() | ||||||
|  |         expiry = now + duration_in_ms | ||||||
|  |         macaroon.add_first_party_caveat("time < %d" % (expiry,)) | ||||||
|  |         return macaroon.serialize() | ||||||
|  | 
 | ||||||
|  |     def generate_delete_pusher_token(self, user_id): | ||||||
|  |         macaroon = self._generate_base_macaroon(user_id) | ||||||
|  |         macaroon.add_first_party_caveat("type = delete_pusher") | ||||||
|  |         return macaroon.serialize() | ||||||
|  | 
 | ||||||
|  |     def _generate_base_macaroon(self, user_id): | ||||||
|  |         macaroon = pymacaroons.Macaroon( | ||||||
|  |             location=self.server_name, | ||||||
|  |             identifier="key", | ||||||
|  |             key=self.macaroon_secret_key) | ||||||
|  |         macaroon.add_first_party_caveat("gen = 1") | ||||||
|  |         macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) | ||||||
|  |         return macaroon | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| class _AccountHandler(object): | class _AccountHandler(object): | ||||||
|     """A proxy object that gets passed to password auth providers so they |     """A proxy object that gets passed to password auth providers so they | ||||||
|     can register new users etc if necessary. |     can register new users etc if necessary. | ||||||
|  | |||||||
| @ -17,7 +17,7 @@ from synapse.api import errors | |||||||
| from synapse.api.constants import EventTypes | from synapse.api.constants import EventTypes | ||||||
| from synapse.util import stringutils | from synapse.util import stringutils | ||||||
| from synapse.util.async import Linearizer | from synapse.util.async import Linearizer | ||||||
| from synapse.types import get_domain_from_id | from synapse.types import get_domain_from_id, RoomStreamToken | ||||||
| from twisted.internet import defer | from twisted.internet import defer | ||||||
| from ._base import BaseHandler | from ._base import BaseHandler | ||||||
| 
 | 
 | ||||||
| @ -198,20 +198,22 @@ class DeviceHandler(BaseHandler): | |||||||
|         """Notify that a user's device(s) has changed. Pokes the notifier, and |         """Notify that a user's device(s) has changed. Pokes the notifier, and | ||||||
|         remote servers if the user is local. |         remote servers if the user is local. | ||||||
|         """ |         """ | ||||||
|         rooms = yield self.store.get_rooms_for_user(user_id) |         users_who_share_room = yield self.store.get_users_who_share_room_with_user( | ||||||
|         room_ids = [r.room_id for r in rooms] |             user_id | ||||||
|  |         ) | ||||||
| 
 | 
 | ||||||
|         hosts = set() |         hosts = set() | ||||||
|         if self.hs.is_mine_id(user_id): |         if self.hs.is_mine_id(user_id): | ||||||
|             for room_id in room_ids: |             hosts.update(get_domain_from_id(u) for u in users_who_share_room) | ||||||
|                 users = yield self.store.get_users_in_room(room_id) |  | ||||||
|                 hosts.update(get_domain_from_id(u) for u in users) |  | ||||||
|             hosts.discard(self.server_name) |             hosts.discard(self.server_name) | ||||||
| 
 | 
 | ||||||
|         position = yield self.store.add_device_change_to_streams( |         position = yield self.store.add_device_change_to_streams( | ||||||
|             user_id, device_ids, list(hosts) |             user_id, device_ids, list(hosts) | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|  |         rooms = yield self.store.get_rooms_for_user(user_id) | ||||||
|  |         room_ids = [r.room_id for r in rooms] | ||||||
|  | 
 | ||||||
|         yield self.notifier.on_new_event( |         yield self.notifier.on_new_event( | ||||||
|             "device_list_key", position, rooms=room_ids, |             "device_list_key", position, rooms=room_ids, | ||||||
|         ) |         ) | ||||||
| @ -243,15 +245,15 @@ class DeviceHandler(BaseHandler): | |||||||
| 
 | 
 | ||||||
|         possibly_changed = set(changed) |         possibly_changed = set(changed) | ||||||
|         for room_id in rooms_changed: |         for room_id in rooms_changed: | ||||||
|             # Fetch (an approximation) of the current state at the time. |             # Fetch  the current state at the time. | ||||||
|             event_rows, token = yield self.store.get_recent_event_ids_for_room( |             stream_ordering = RoomStreamToken.parse_stream_token(from_token.room_key) | ||||||
|                 room_id, end_token=from_token.room_key, limit=1, |  | ||||||
|             ) |  | ||||||
| 
 | 
 | ||||||
|             if event_rows: |             try: | ||||||
|                 last_event_id = event_rows[-1]["event_id"] |                 event_ids = yield self.store.get_forward_extremeties_for_room( | ||||||
|                 prev_state_ids = yield self.store.get_state_ids_for_event(last_event_id) |                     room_id, stream_ordering=stream_ordering | ||||||
|             else: |                 ) | ||||||
|  |                 prev_state_ids = yield self.store.get_state_ids_for_events(event_ids) | ||||||
|  |             except: | ||||||
|                 prev_state_ids = {} |                 prev_state_ids = {} | ||||||
| 
 | 
 | ||||||
|             current_state_ids = yield self.state.get_current_state_ids(room_id) |             current_state_ids = yield self.state.get_current_state_ids(room_id) | ||||||
| @ -266,13 +268,13 @@ class DeviceHandler(BaseHandler): | |||||||
|                     if not prev_event_id or prev_event_id != event_id: |                     if not prev_event_id or prev_event_id != event_id: | ||||||
|                         possibly_changed.add(state_key) |                         possibly_changed.add(state_key) | ||||||
| 
 | 
 | ||||||
|         user_ids_changed = set() |         users_who_share_room = yield self.store.get_users_who_share_room_with_user( | ||||||
|         for other_user_id in possibly_changed: |             user_id | ||||||
|             other_rooms = yield self.store.get_rooms_for_user(other_user_id) |         ) | ||||||
|             if room_ids.intersection(e.room_id for e in other_rooms): |  | ||||||
|                 user_ids_changed.add(other_user_id) |  | ||||||
| 
 | 
 | ||||||
|         defer.returnValue(user_ids_changed) |         # Take the intersection of the users whose devices may have changed | ||||||
|  |         # and those that actually still share a room with the user | ||||||
|  |         defer.returnValue(users_who_share_room & possibly_changed) | ||||||
| 
 | 
 | ||||||
|     @defer.inlineCallbacks |     @defer.inlineCallbacks | ||||||
|     def _incoming_device_list_update(self, origin, edu_content): |     def _incoming_device_list_update(self, origin, edu_content): | ||||||
|  | |||||||
| @ -1011,7 +1011,7 @@ class PresenceEventSource(object): | |||||||
|     @defer.inlineCallbacks |     @defer.inlineCallbacks | ||||||
|     @log_function |     @log_function | ||||||
|     def get_new_events(self, user, from_key, room_ids=None, include_offline=True, |     def get_new_events(self, user, from_key, room_ids=None, include_offline=True, | ||||||
|                        **kwargs): |                        explicit_room_id=None, **kwargs): | ||||||
|         # The process for getting presence events are: |         # The process for getting presence events are: | ||||||
|         #  1. Get the rooms the user is in. |         #  1. Get the rooms the user is in. | ||||||
|         #  2. Get the list of user in the rooms. |         #  2. Get the list of user in the rooms. | ||||||
| @ -1028,22 +1028,24 @@ class PresenceEventSource(object): | |||||||
|             user_id = user.to_string() |             user_id = user.to_string() | ||||||
|             if from_key is not None: |             if from_key is not None: | ||||||
|                 from_key = int(from_key) |                 from_key = int(from_key) | ||||||
|             room_ids = room_ids or [] |  | ||||||
| 
 | 
 | ||||||
|             presence = self.get_presence_handler() |             presence = self.get_presence_handler() | ||||||
|             stream_change_cache = self.store.presence_stream_cache |             stream_change_cache = self.store.presence_stream_cache | ||||||
| 
 | 
 | ||||||
|             if not room_ids: |  | ||||||
|                 rooms = yield self.store.get_rooms_for_user(user_id) |  | ||||||
|                 room_ids = set(e.room_id for e in rooms) |  | ||||||
|             else: |  | ||||||
|                 room_ids = set(room_ids) |  | ||||||
| 
 |  | ||||||
|             max_token = self.store.get_current_presence_token() |             max_token = self.store.get_current_presence_token() | ||||||
| 
 | 
 | ||||||
|             plist = yield self.store.get_presence_list_accepted(user.localpart) |             plist = yield self.store.get_presence_list_accepted(user.localpart) | ||||||
|             friends = set(row["observed_user_id"] for row in plist) |             users_interested_in = set(row["observed_user_id"] for row in plist) | ||||||
|             friends.add(user_id)  # So that we receive our own presence |             users_interested_in.add(user_id)  # So that we receive our own presence | ||||||
|  | 
 | ||||||
|  |             users_who_share_room = yield self.store.get_users_who_share_room_with_user( | ||||||
|  |                 user_id | ||||||
|  |             ) | ||||||
|  |             users_interested_in.update(users_who_share_room) | ||||||
|  | 
 | ||||||
|  |             if explicit_room_id: | ||||||
|  |                 user_ids = yield self.store.get_users_in_room(explicit_room_id) | ||||||
|  |                 users_interested_in.update(user_ids) | ||||||
| 
 | 
 | ||||||
|             user_ids_changed = set() |             user_ids_changed = set() | ||||||
|             changed = None |             changed = None | ||||||
| @ -1055,35 +1057,19 @@ class PresenceEventSource(object): | |||||||
|                 # work out if we share a room or they're in our presence list |                 # work out if we share a room or they're in our presence list | ||||||
|                 get_updates_counter.inc("stream") |                 get_updates_counter.inc("stream") | ||||||
|                 for other_user_id in changed: |                 for other_user_id in changed: | ||||||
|                     if other_user_id in friends: |                     if other_user_id in users_interested_in: | ||||||
|                         user_ids_changed.add(other_user_id) |                         user_ids_changed.add(other_user_id) | ||||||
|                         continue |  | ||||||
|                     other_rooms = yield self.store.get_rooms_for_user(other_user_id) |  | ||||||
|                     if room_ids.intersection(e.room_id for e in other_rooms): |  | ||||||
|                         user_ids_changed.add(other_user_id) |  | ||||||
|                         continue |  | ||||||
|             else: |             else: | ||||||
|                 # Too many possible updates. Find all users we can see and check |                 # Too many possible updates. Find all users we can see and check | ||||||
|                 # if any of them have changed. |                 # if any of them have changed. | ||||||
|                 get_updates_counter.inc("full") |                 get_updates_counter.inc("full") | ||||||
| 
 | 
 | ||||||
|                 user_ids_to_check = set() |  | ||||||
|                 for room_id in room_ids: |  | ||||||
|                     users = yield self.store.get_users_in_room(room_id) |  | ||||||
|                     user_ids_to_check.update(users) |  | ||||||
| 
 |  | ||||||
|                 user_ids_to_check.update(friends) |  | ||||||
| 
 |  | ||||||
|                 # Always include yourself. Only really matters for when the user is |  | ||||||
|                 # not in any rooms, but still. |  | ||||||
|                 user_ids_to_check.add(user_id) |  | ||||||
| 
 |  | ||||||
|                 if from_key: |                 if from_key: | ||||||
|                     user_ids_changed = stream_change_cache.get_entities_changed( |                     user_ids_changed = stream_change_cache.get_entities_changed( | ||||||
|                         user_ids_to_check, from_key, |                         users_interested_in, from_key, | ||||||
|                     ) |                     ) | ||||||
|                 else: |                 else: | ||||||
|                     user_ids_changed = user_ids_to_check |                     user_ids_changed = users_interested_in | ||||||
| 
 | 
 | ||||||
|             updates = yield presence.current_state_for_users(user_ids_changed) |             updates = yield presence.current_state_for_users(user_ids_changed) | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -40,6 +40,8 @@ class RegistrationHandler(BaseHandler): | |||||||
| 
 | 
 | ||||||
|         self._next_generated_user_id = None |         self._next_generated_user_id = None | ||||||
| 
 | 
 | ||||||
|  |         self.macaroon_gen = hs.get_macaroon_generator() | ||||||
|  | 
 | ||||||
|     @defer.inlineCallbacks |     @defer.inlineCallbacks | ||||||
|     def check_username(self, localpart, guest_access_token=None, |     def check_username(self, localpart, guest_access_token=None, | ||||||
|                        assigned_user_id=None): |                        assigned_user_id=None): | ||||||
| @ -143,7 +145,7 @@ class RegistrationHandler(BaseHandler): | |||||||
| 
 | 
 | ||||||
|             token = None |             token = None | ||||||
|             if generate_token: |             if generate_token: | ||||||
|                 token = self.auth_handler().generate_access_token(user_id) |                 token = self.macaroon_gen.generate_access_token(user_id) | ||||||
|             yield self.store.register( |             yield self.store.register( | ||||||
|                 user_id=user_id, |                 user_id=user_id, | ||||||
|                 token=token, |                 token=token, | ||||||
| @ -167,7 +169,7 @@ class RegistrationHandler(BaseHandler): | |||||||
|                 user_id = user.to_string() |                 user_id = user.to_string() | ||||||
|                 yield self.check_user_id_not_appservice_exclusive(user_id) |                 yield self.check_user_id_not_appservice_exclusive(user_id) | ||||||
|                 if generate_token: |                 if generate_token: | ||||||
|                     token = self.auth_handler().generate_access_token(user_id) |                     token = self.macaroon_gen.generate_access_token(user_id) | ||||||
|                 try: |                 try: | ||||||
|                     yield self.store.register( |                     yield self.store.register( | ||||||
|                         user_id=user_id, |                         user_id=user_id, | ||||||
| @ -254,7 +256,7 @@ class RegistrationHandler(BaseHandler): | |||||||
|         user_id = user.to_string() |         user_id = user.to_string() | ||||||
| 
 | 
 | ||||||
|         yield self.check_user_id_not_appservice_exclusive(user_id) |         yield self.check_user_id_not_appservice_exclusive(user_id) | ||||||
|         token = self.auth_handler().generate_access_token(user_id) |         token = self.macaroon_gen.generate_access_token(user_id) | ||||||
|         try: |         try: | ||||||
|             yield self.store.register( |             yield self.store.register( | ||||||
|                 user_id=user_id, |                 user_id=user_id, | ||||||
| @ -399,7 +401,7 @@ class RegistrationHandler(BaseHandler): | |||||||
| 
 | 
 | ||||||
|         user = UserID(localpart, self.hs.hostname) |         user = UserID(localpart, self.hs.hostname) | ||||||
|         user_id = user.to_string() |         user_id = user.to_string() | ||||||
|         token = self.auth_handler().generate_access_token(user_id) |         token = self.macaroon_gen.generate_access_token(user_id) | ||||||
| 
 | 
 | ||||||
|         if need_register: |         if need_register: | ||||||
|             yield self.store.register( |             yield self.store.register( | ||||||
|  | |||||||
| @ -437,6 +437,7 @@ class RoomEventSource(object): | |||||||
|             limit, |             limit, | ||||||
|             room_ids, |             room_ids, | ||||||
|             is_guest, |             is_guest, | ||||||
|  |             explicit_room_id=None, | ||||||
|     ): |     ): | ||||||
|         # We just ignore the key for now. |         # We just ignore the key for now. | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -378,6 +378,7 @@ class Notifier(object): | |||||||
|                     limit=limit, |                     limit=limit, | ||||||
|                     is_guest=is_peeking, |                     is_guest=is_peeking, | ||||||
|                     room_ids=room_ids, |                     room_ids=room_ids, | ||||||
|  |                     explicit_room_id=explicit_room_id, | ||||||
|                 ) |                 ) | ||||||
| 
 | 
 | ||||||
|                 if name == "room": |                 if name == "room": | ||||||
|  | |||||||
| @ -81,7 +81,7 @@ class Mailer(object): | |||||||
|     def __init__(self, hs, app_name): |     def __init__(self, hs, app_name): | ||||||
|         self.hs = hs |         self.hs = hs | ||||||
|         self.store = self.hs.get_datastore() |         self.store = self.hs.get_datastore() | ||||||
|         self.auth_handler = self.hs.get_auth_handler() |         self.macaroon_gen = self.hs.get_macaroon_generator() | ||||||
|         self.state_handler = self.hs.get_state_handler() |         self.state_handler = self.hs.get_state_handler() | ||||||
|         loader = jinja2.FileSystemLoader(self.hs.config.email_template_dir) |         loader = jinja2.FileSystemLoader(self.hs.config.email_template_dir) | ||||||
|         self.app_name = app_name |         self.app_name = app_name | ||||||
| @ -466,7 +466,7 @@ class Mailer(object): | |||||||
| 
 | 
 | ||||||
|     def make_unsubscribe_link(self, user_id, app_id, email_address): |     def make_unsubscribe_link(self, user_id, app_id, email_address): | ||||||
|         params = { |         params = { | ||||||
|             "access_token": self.auth_handler.generate_delete_pusher_token(user_id), |             "access_token": self.macaroon_gen.generate_delete_pusher_token(user_id), | ||||||
|             "app_id": app_id, |             "app_id": app_id, | ||||||
|             "pushkey": email_address, |             "pushkey": email_address, | ||||||
|         } |         } | ||||||
|  | |||||||
| @ -73,6 +73,9 @@ class SlavedEventStore(BaseSlavedStore): | |||||||
|     # to reach inside the __dict__ to extract them. |     # to reach inside the __dict__ to extract them. | ||||||
|     get_rooms_for_user = RoomMemberStore.__dict__["get_rooms_for_user"] |     get_rooms_for_user = RoomMemberStore.__dict__["get_rooms_for_user"] | ||||||
|     get_users_in_room = RoomMemberStore.__dict__["get_users_in_room"] |     get_users_in_room = RoomMemberStore.__dict__["get_users_in_room"] | ||||||
|  |     get_users_who_share_room_with_user = ( | ||||||
|  |         RoomMemberStore.__dict__["get_users_who_share_room_with_user"] | ||||||
|  |     ) | ||||||
|     get_latest_event_ids_in_room = EventFederationStore.__dict__[ |     get_latest_event_ids_in_room = EventFederationStore.__dict__[ | ||||||
|         "get_latest_event_ids_in_room" |         "get_latest_event_ids_in_room" | ||||||
|     ] |     ] | ||||||
|  | |||||||
| @ -330,6 +330,7 @@ class CasTicketServlet(ClientV1RestServlet): | |||||||
|         self.cas_required_attributes = hs.config.cas_required_attributes |         self.cas_required_attributes = hs.config.cas_required_attributes | ||||||
|         self.auth_handler = hs.get_auth_handler() |         self.auth_handler = hs.get_auth_handler() | ||||||
|         self.handlers = hs.get_handlers() |         self.handlers = hs.get_handlers() | ||||||
|  |         self.macaroon_gen = hs.get_macaroon_generator() | ||||||
| 
 | 
 | ||||||
|     @defer.inlineCallbacks |     @defer.inlineCallbacks | ||||||
|     def on_GET(self, request): |     def on_GET(self, request): | ||||||
| @ -368,7 +369,9 @@ class CasTicketServlet(ClientV1RestServlet): | |||||||
|                 yield self.handlers.registration_handler.register(localpart=user) |                 yield self.handlers.registration_handler.register(localpart=user) | ||||||
|             ) |             ) | ||||||
| 
 | 
 | ||||||
|         login_token = auth_handler.generate_short_term_login_token(registered_user_id) |         login_token = self.macaroon_gen.generate_short_term_login_token( | ||||||
|  |             registered_user_id | ||||||
|  |         ) | ||||||
|         redirect_url = self.add_login_token_to_redirect_url(client_redirect_url, |         redirect_url = self.add_login_token_to_redirect_url(client_redirect_url, | ||||||
|                                                             login_token) |                                                             login_token) | ||||||
|         request.redirect(redirect_url) |         request.redirect(redirect_url) | ||||||
|  | |||||||
| @ -193,7 +193,7 @@ class KeyChangesServlet(RestServlet): | |||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|         defer.returnValue((200, { |         defer.returnValue((200, { | ||||||
|             "changed": changed |             "changed": list(changed), | ||||||
|         })) |         })) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -96,6 +96,7 @@ class RegisterRestServlet(RestServlet): | |||||||
|         self.registration_handler = hs.get_handlers().registration_handler |         self.registration_handler = hs.get_handlers().registration_handler | ||||||
|         self.identity_handler = hs.get_handlers().identity_handler |         self.identity_handler = hs.get_handlers().identity_handler | ||||||
|         self.device_handler = hs.get_device_handler() |         self.device_handler = hs.get_device_handler() | ||||||
|  |         self.macaroon_gen = hs.get_macaroon_generator() | ||||||
| 
 | 
 | ||||||
|     @defer.inlineCallbacks |     @defer.inlineCallbacks | ||||||
|     def on_POST(self, request): |     def on_POST(self, request): | ||||||
| @ -436,7 +437,7 @@ class RegisterRestServlet(RestServlet): | |||||||
|             user_id, device_id, initial_display_name |             user_id, device_id, initial_display_name | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|         access_token = self.auth_handler.generate_access_token( |         access_token = self.macaroon_gen.generate_access_token( | ||||||
|             user_id, ["guest = true"] |             user_id, ["guest = true"] | ||||||
|         ) |         ) | ||||||
|         defer.returnValue((200, { |         defer.returnValue((200, { | ||||||
|  | |||||||
| @ -37,7 +37,7 @@ from synapse.federation.transport.client import TransportLayerClient | |||||||
| from synapse.federation.transaction_queue import TransactionQueue | from synapse.federation.transaction_queue import TransactionQueue | ||||||
| from synapse.handlers import Handlers | from synapse.handlers import Handlers | ||||||
| from synapse.handlers.appservice import ApplicationServicesHandler | from synapse.handlers.appservice import ApplicationServicesHandler | ||||||
| from synapse.handlers.auth import AuthHandler | from synapse.handlers.auth import AuthHandler, MacaroonGeneartor | ||||||
| from synapse.handlers.devicemessage import DeviceMessageHandler | from synapse.handlers.devicemessage import DeviceMessageHandler | ||||||
| from synapse.handlers.device import DeviceHandler | from synapse.handlers.device import DeviceHandler | ||||||
| from synapse.handlers.e2e_keys import E2eKeysHandler | from synapse.handlers.e2e_keys import E2eKeysHandler | ||||||
| @ -131,6 +131,7 @@ class HomeServer(object): | |||||||
|         'federation_transport_client', |         'federation_transport_client', | ||||||
|         'federation_sender', |         'federation_sender', | ||||||
|         'receipts_handler', |         'receipts_handler', | ||||||
|  |         'macaroon_generator', | ||||||
|     ] |     ] | ||||||
| 
 | 
 | ||||||
|     def __init__(self, hostname, **kwargs): |     def __init__(self, hostname, **kwargs): | ||||||
| @ -213,6 +214,9 @@ class HomeServer(object): | |||||||
|     def build_auth_handler(self): |     def build_auth_handler(self): | ||||||
|         return AuthHandler(self) |         return AuthHandler(self) | ||||||
| 
 | 
 | ||||||
|  |     def build_macaroon_generator(self): | ||||||
|  |         return MacaroonGeneartor(self) | ||||||
|  | 
 | ||||||
|     def build_device_handler(self): |     def build_device_handler(self): | ||||||
|         return DeviceHandler(self) |         return DeviceHandler(self) | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -280,6 +280,23 @@ class RoomMemberStore(SQLBaseStore): | |||||||
|             user_id, membership_list=[Membership.JOIN], |             user_id, membership_list=[Membership.JOIN], | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|  |     @cachedInlineCallbacks(max_entries=50000, cache_context=True, iterable=True) | ||||||
|  |     def get_users_who_share_room_with_user(self, user_id, cache_context): | ||||||
|  |         """Returns the set of users who share a room with `user_id` | ||||||
|  |         """ | ||||||
|  |         rooms = yield self.get_rooms_for_user( | ||||||
|  |             user_id, on_invalidate=cache_context.invalidate, | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         user_who_share_room = set() | ||||||
|  |         for room in rooms: | ||||||
|  |             user_ids = yield self.get_users_in_room( | ||||||
|  |                 room.room_id, on_invalidate=cache_context.invalidate, | ||||||
|  |             ) | ||||||
|  |             user_who_share_room.update(user_ids) | ||||||
|  | 
 | ||||||
|  |         defer.returnValue(user_who_share_room) | ||||||
|  | 
 | ||||||
|     def forget(self, user_id, room_id): |     def forget(self, user_id, room_id): | ||||||
|         """Indicate that user_id wishes to discard history for room_id.""" |         """Indicate that user_id wishes to discard history for room_id.""" | ||||||
|         def f(txn): |         def f(txn): | ||||||
|  | |||||||
| @ -478,6 +478,11 @@ class CacheListDescriptor(object): | |||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))): | class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))): | ||||||
|  |     # We rely on _CacheContext implementing __eq__ and __hash__ sensibly, | ||||||
|  |     # which namedtuple does for us (i.e. two _CacheContext are the same if | ||||||
|  |     # their caches and keys match). This is important in particular to | ||||||
|  |     # dedupe when we add callbacks to lru cache nodes, otherwise the number | ||||||
|  |     # of callbacks would grow. | ||||||
|     def invalidate(self): |     def invalidate(self): | ||||||
|         self.cache.invalidate(self.key) |         self.cache.invalidate(self.key) | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -34,11 +34,10 @@ class AuthTestCase(unittest.TestCase): | |||||||
|         self.hs = yield setup_test_homeserver(handlers=None) |         self.hs = yield setup_test_homeserver(handlers=None) | ||||||
|         self.hs.handlers = AuthHandlers(self.hs) |         self.hs.handlers = AuthHandlers(self.hs) | ||||||
|         self.auth_handler = self.hs.handlers.auth_handler |         self.auth_handler = self.hs.handlers.auth_handler | ||||||
|  |         self.macaroon_generator = self.hs.get_macaroon_generator() | ||||||
| 
 | 
 | ||||||
|     def test_token_is_a_macaroon(self): |     def test_token_is_a_macaroon(self): | ||||||
|         self.hs.config.macaroon_secret_key = "this key is a huge secret" |         token = self.macaroon_generator.generate_access_token("some_user") | ||||||
| 
 |  | ||||||
|         token = self.auth_handler.generate_access_token("some_user") |  | ||||||
|         # Check that we can parse the thing with pymacaroons |         # Check that we can parse the thing with pymacaroons | ||||||
|         macaroon = pymacaroons.Macaroon.deserialize(token) |         macaroon = pymacaroons.Macaroon.deserialize(token) | ||||||
|         # The most basic of sanity checks |         # The most basic of sanity checks | ||||||
| @ -46,10 +45,9 @@ class AuthTestCase(unittest.TestCase): | |||||||
|             self.fail("some_user was not in %s" % macaroon.inspect()) |             self.fail("some_user was not in %s" % macaroon.inspect()) | ||||||
| 
 | 
 | ||||||
|     def test_macaroon_caveats(self): |     def test_macaroon_caveats(self): | ||||||
|         self.hs.config.macaroon_secret_key = "this key is a massive secret" |  | ||||||
|         self.hs.clock.now = 5000 |         self.hs.clock.now = 5000 | ||||||
| 
 | 
 | ||||||
|         token = self.auth_handler.generate_access_token("a_user") |         token = self.macaroon_generator.generate_access_token("a_user") | ||||||
|         macaroon = pymacaroons.Macaroon.deserialize(token) |         macaroon = pymacaroons.Macaroon.deserialize(token) | ||||||
| 
 | 
 | ||||||
|         def verify_gen(caveat): |         def verify_gen(caveat): | ||||||
| @ -74,7 +72,7 @@ class AuthTestCase(unittest.TestCase): | |||||||
|     def test_short_term_login_token_gives_user_id(self): |     def test_short_term_login_token_gives_user_id(self): | ||||||
|         self.hs.clock.now = 1000 |         self.hs.clock.now = 1000 | ||||||
| 
 | 
 | ||||||
|         token = self.auth_handler.generate_short_term_login_token( |         token = self.macaroon_generator.generate_short_term_login_token( | ||||||
|             "a_user", 5000 |             "a_user", 5000 | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
| @ -93,7 +91,7 @@ class AuthTestCase(unittest.TestCase): | |||||||
|             ) |             ) | ||||||
| 
 | 
 | ||||||
|     def test_short_term_login_token_cannot_replace_user_id(self): |     def test_short_term_login_token_cannot_replace_user_id(self): | ||||||
|         token = self.auth_handler.generate_short_term_login_token( |         token = self.macaroon_generator.generate_short_term_login_token( | ||||||
|             "a_user", 5000 |             "a_user", 5000 | ||||||
|         ) |         ) | ||||||
|         macaroon = pymacaroons.Macaroon.deserialize(token) |         macaroon = pymacaroons.Macaroon.deserialize(token) | ||||||
|  | |||||||
| @ -41,15 +41,12 @@ class RegistrationTestCase(unittest.TestCase): | |||||||
|             handlers=None, |             handlers=None, | ||||||
|             http_client=None, |             http_client=None, | ||||||
|             expire_access_token=True) |             expire_access_token=True) | ||||||
|         self.auth_handler = Mock( |         self.macaroon_generator = Mock( | ||||||
|             generate_access_token=Mock(return_value='secret')) |             generate_access_token=Mock(return_value='secret')) | ||||||
|  |         self.hs.get_macaroon_generator = Mock(return_value=self.macaroon_generator) | ||||||
|         self.hs.handlers = RegistrationHandlers(self.hs) |         self.hs.handlers = RegistrationHandlers(self.hs) | ||||||
|         self.handler = self.hs.get_handlers().registration_handler |         self.handler = self.hs.get_handlers().registration_handler | ||||||
|         self.hs.get_handlers().profile_handler = Mock() |         self.hs.get_handlers().profile_handler = Mock() | ||||||
|         self.mock_handler = Mock(spec=[ |  | ||||||
|             "generate_access_token", |  | ||||||
|         ]) |  | ||||||
|         self.hs.get_auth_handler = Mock(return_value=self.auth_handler) |  | ||||||
| 
 | 
 | ||||||
|     @defer.inlineCallbacks |     @defer.inlineCallbacks | ||||||
|     def test_user_is_created_and_logged_in_if_doesnt_exist(self): |     def test_user_is_created_and_logged_in_if_doesnt_exist(self): | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user