mirror of
				https://github.com/matrix-org/synapse.git
				synced 2025-10-25 22:32:03 +02:00 
			
		
		
		
	Merge pull request #392 from matrix-org/markjh/client_config
Add API for setting per user account data at the top level or room level.
This commit is contained in:
		
						commit
						a2922bb944
					
				| @ -50,7 +50,7 @@ class Filtering(object): | ||||
|         # many definitions. | ||||
| 
 | ||||
|         top_level_definitions = [ | ||||
|             "presence" | ||||
|             "presence", "account_data" | ||||
|         ] | ||||
| 
 | ||||
|         room_level_definitions = [ | ||||
| @ -139,6 +139,10 @@ class FilterCollection(object): | ||||
|             self.filter_json.get("presence", {}) | ||||
|         ) | ||||
| 
 | ||||
|         self.account_data = Filter( | ||||
|             self.filter_json.get("account_data", {}) | ||||
|         ) | ||||
| 
 | ||||
|     def timeline_limit(self): | ||||
|         return self.room_timeline_filter.limit() | ||||
| 
 | ||||
| @ -151,6 +155,9 @@ class FilterCollection(object): | ||||
|     def filter_presence(self, events): | ||||
|         return self.presence_filter.filter(events) | ||||
| 
 | ||||
|     def filter_account_data(self, events): | ||||
|         return self.account_data.filter(events) | ||||
| 
 | ||||
|     def filter_room_state(self, events): | ||||
|         return self.room_state_filter.filter(events) | ||||
| 
 | ||||
|  | ||||
| @ -29,9 +29,10 @@ class AccountDataEventSource(object): | ||||
|         last_stream_id = from_key | ||||
| 
 | ||||
|         current_stream_id = yield self.store.get_max_account_data_stream_id() | ||||
|         tags = yield self.store.get_updated_tags(user_id, last_stream_id) | ||||
| 
 | ||||
|         results = [] | ||||
|         tags = yield self.store.get_updated_tags(user_id, last_stream_id) | ||||
| 
 | ||||
|         for room_id, room_tags in tags.items(): | ||||
|             results.append({ | ||||
|                 "type": "m.tag", | ||||
| @ -39,6 +40,24 @@ class AccountDataEventSource(object): | ||||
|                 "room_id": room_id, | ||||
|             }) | ||||
| 
 | ||||
|         account_data, room_account_data = ( | ||||
|             yield self.store.get_updated_account_data_for_user(user_id, last_stream_id) | ||||
|         ) | ||||
| 
 | ||||
|         for account_data_type, content in account_data.items(): | ||||
|             results.append({ | ||||
|                 "type": account_data_type, | ||||
|                 "content": content, | ||||
|             }) | ||||
| 
 | ||||
|         for room_id, account_data in room_account_data.items(): | ||||
|             for account_data_type, content in account_data.items(): | ||||
|                 results.append({ | ||||
|                     "type": account_data_type, | ||||
|                     "content": content, | ||||
|                     "room_id": room_id, | ||||
|                 }) | ||||
| 
 | ||||
|         defer.returnValue((results, current_stream_id)) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|  | ||||
| @ -359,6 +359,10 @@ class MessageHandler(BaseHandler): | ||||
| 
 | ||||
|         tags_by_room = yield self.store.get_tags_for_user(user_id) | ||||
| 
 | ||||
|         account_data, account_data_by_room = ( | ||||
|             yield self.store.get_account_data_for_user(user_id) | ||||
|         ) | ||||
| 
 | ||||
|         public_room_ids = yield self.store.get_public_room_ids() | ||||
| 
 | ||||
|         limit = pagin_config.limit | ||||
| @ -436,14 +440,22 @@ class MessageHandler(BaseHandler): | ||||
|                     for c in current_state.values() | ||||
|                 ] | ||||
| 
 | ||||
|                 account_data = [] | ||||
|                 account_data_events = [] | ||||
|                 tags = tags_by_room.get(event.room_id) | ||||
|                 if tags: | ||||
|                     account_data.append({ | ||||
|                     account_data_events.append({ | ||||
|                         "type": "m.tag", | ||||
|                         "content": {"tags": tags}, | ||||
|                     }) | ||||
|                 d["account_data"] = account_data | ||||
| 
 | ||||
|                 account_data = account_data_by_room.get(event.room_id, {}) | ||||
|                 for account_data_type, content in account_data.items(): | ||||
|                     account_data_events.append({ | ||||
|                         "type": account_data_type, | ||||
|                         "content": content, | ||||
|                     }) | ||||
| 
 | ||||
|                 d["account_data"] = account_data_events | ||||
|             except: | ||||
|                 logger.exception("Failed to get snapshot") | ||||
| 
 | ||||
| @ -456,9 +468,17 @@ class MessageHandler(BaseHandler): | ||||
|                 consumeErrors=True | ||||
|             ).addErrback(unwrapFirstError) | ||||
| 
 | ||||
|         account_data_events = [] | ||||
|         for account_data_type, content in account_data.items(): | ||||
|             account_data_events.append({ | ||||
|                 "type": account_data_type, | ||||
|                 "content": content, | ||||
|             }) | ||||
| 
 | ||||
|         ret = { | ||||
|             "rooms": rooms_ret, | ||||
|             "presence": presence, | ||||
|             "account_data": account_data_events, | ||||
|             "receipts": receipt, | ||||
|             "end": now_token.to_string(), | ||||
|         } | ||||
| @ -498,14 +518,22 @@ class MessageHandler(BaseHandler): | ||||
|                 user_id, room_id, pagin_config, membership, member_event_id, is_guest | ||||
|             ) | ||||
| 
 | ||||
|         account_data = [] | ||||
|         account_data_events = [] | ||||
|         tags = yield self.store.get_tags_for_room(user_id, room_id) | ||||
|         if tags: | ||||
|             account_data.append({ | ||||
|             account_data_events.append({ | ||||
|                 "type": "m.tag", | ||||
|                 "content": {"tags": tags}, | ||||
|             }) | ||||
|         result["account_data"] = account_data | ||||
| 
 | ||||
|         account_data = yield self.store.get_account_data_for_room(user_id, room_id) | ||||
|         for account_data_type, content in account_data.items(): | ||||
|             account_data_events.append({ | ||||
|                 "type": account_data_type, | ||||
|                 "content": content, | ||||
|             }) | ||||
| 
 | ||||
|         result["account_data"] = account_data_events | ||||
| 
 | ||||
|         defer.returnValue(result) | ||||
| 
 | ||||
|  | ||||
| @ -100,6 +100,7 @@ class InvitedSyncResult(collections.namedtuple("InvitedSyncResult", [ | ||||
| class SyncResult(collections.namedtuple("SyncResult", [ | ||||
|     "next_batch",  # Token for the next sync | ||||
|     "presence",  # List of presence events for the user. | ||||
|     "account_data",  # List of account_data events for the user. | ||||
|     "joined",  # JoinedSyncResult for each joined room. | ||||
|     "invited",  # InvitedSyncResult for each invited room. | ||||
|     "archived",  # ArchivedSyncResult for each archived room. | ||||
| @ -195,6 +196,12 @@ class SyncHandler(BaseHandler): | ||||
|             ) | ||||
|         ) | ||||
| 
 | ||||
|         account_data, account_data_by_room = ( | ||||
|             yield self.store.get_account_data_for_user( | ||||
|                 sync_config.user.to_string() | ||||
|             ) | ||||
|         ) | ||||
| 
 | ||||
|         tags_by_room = yield self.store.get_tags_for_user( | ||||
|             sync_config.user.to_string() | ||||
|         ) | ||||
| @ -211,6 +218,7 @@ class SyncHandler(BaseHandler): | ||||
|                     timeline_since_token=timeline_since_token, | ||||
|                     ephemeral_by_room=ephemeral_by_room, | ||||
|                     tags_by_room=tags_by_room, | ||||
|                     account_data_by_room=account_data_by_room, | ||||
|                 ) | ||||
|                 joined.append(room_sync) | ||||
|             elif event.membership == Membership.INVITE: | ||||
| @ -230,11 +238,13 @@ class SyncHandler(BaseHandler): | ||||
|                     leave_token=leave_token, | ||||
|                     timeline_since_token=timeline_since_token, | ||||
|                     tags_by_room=tags_by_room, | ||||
|                     account_data_by_room=account_data_by_room, | ||||
|                 ) | ||||
|                 archived.append(room_sync) | ||||
| 
 | ||||
|         defer.returnValue(SyncResult( | ||||
|             presence=presence, | ||||
|             account_data=self.account_data_for_user(account_data), | ||||
|             joined=joined, | ||||
|             invited=invited, | ||||
|             archived=archived, | ||||
| @ -244,7 +254,8 @@ class SyncHandler(BaseHandler): | ||||
|     @defer.inlineCallbacks | ||||
|     def full_state_sync_for_joined_room(self, room_id, sync_config, | ||||
|                                         now_token, timeline_since_token, | ||||
|                                         ephemeral_by_room, tags_by_room): | ||||
|                                         ephemeral_by_room, tags_by_room, | ||||
|                                         account_data_by_room): | ||||
|         """Sync a room for a client which is starting without any state | ||||
|         Returns: | ||||
|             A Deferred JoinedSyncResult. | ||||
| @ -262,19 +273,38 @@ class SyncHandler(BaseHandler): | ||||
|             state=current_state, | ||||
|             ephemeral=ephemeral_by_room.get(room_id, []), | ||||
|             account_data=self.account_data_for_room( | ||||
|                 room_id, tags_by_room | ||||
|                 room_id, tags_by_room, account_data_by_room | ||||
|             ), | ||||
|         )) | ||||
| 
 | ||||
|     def account_data_for_room(self, room_id, tags_by_room): | ||||
|         account_data = [] | ||||
|     def account_data_for_user(self, account_data): | ||||
|         account_data_events = [] | ||||
| 
 | ||||
|         for account_data_type, content in account_data.items(): | ||||
|             account_data_events.append({ | ||||
|                 "type": account_data_type, | ||||
|                 "content": content, | ||||
|             }) | ||||
| 
 | ||||
|         return account_data_events | ||||
| 
 | ||||
|     def account_data_for_room(self, room_id, tags_by_room, account_data_by_room): | ||||
|         account_data_events = [] | ||||
|         tags = tags_by_room.get(room_id) | ||||
|         if tags is not None: | ||||
|             account_data.append({ | ||||
|             account_data_events.append({ | ||||
|                 "type": "m.tag", | ||||
|                 "content": {"tags": tags}, | ||||
|             }) | ||||
|         return account_data | ||||
| 
 | ||||
|         account_data = account_data_by_room.get(room_id, {}) | ||||
|         for account_data_type, content in account_data.items(): | ||||
|             account_data_events.append({ | ||||
|                 "type": account_data_type, | ||||
|                 "content": content, | ||||
|             }) | ||||
| 
 | ||||
|         return account_data_events | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def ephemeral_by_room(self, sync_config, now_token, since_token=None): | ||||
| @ -341,7 +371,8 @@ class SyncHandler(BaseHandler): | ||||
|     @defer.inlineCallbacks | ||||
|     def full_state_sync_for_archived_room(self, room_id, sync_config, | ||||
|                                           leave_event_id, leave_token, | ||||
|                                           timeline_since_token, tags_by_room): | ||||
|                                           timeline_since_token, tags_by_room, | ||||
|                                           account_data_by_room): | ||||
|         """Sync a room for a client which is starting without any state | ||||
|         Returns: | ||||
|             A Deferred JoinedSyncResult. | ||||
| @ -358,7 +389,7 @@ class SyncHandler(BaseHandler): | ||||
|             timeline=batch, | ||||
|             state=leave_state, | ||||
|             account_data=self.account_data_for_room( | ||||
|                 room_id, tags_by_room | ||||
|                 room_id, tags_by_room, account_data_by_room | ||||
|             ), | ||||
|         )) | ||||
| 
 | ||||
| @ -415,6 +446,13 @@ class SyncHandler(BaseHandler): | ||||
|             since_token.account_data_key, | ||||
|         ) | ||||
| 
 | ||||
|         account_data, account_data_by_room = ( | ||||
|             yield self.store.get_updated_account_data_for_user( | ||||
|                 sync_config.user.to_string(), | ||||
|                 since_token.account_data_key, | ||||
|             ) | ||||
|         ) | ||||
| 
 | ||||
|         joined = [] | ||||
|         archived = [] | ||||
|         if len(room_events) <= timeline_limit: | ||||
| @ -469,7 +507,7 @@ class SyncHandler(BaseHandler): | ||||
|                     state=state, | ||||
|                     ephemeral=ephemeral_by_room.get(room_id, []), | ||||
|                     account_data=self.account_data_for_room( | ||||
|                         room_id, tags_by_room | ||||
|                         room_id, tags_by_room, account_data_by_room | ||||
|                     ), | ||||
|                 ) | ||||
|                 logger.debug("Result for room %s: %r", room_id, room_sync) | ||||
| @ -492,14 +530,15 @@ class SyncHandler(BaseHandler): | ||||
|             for room_id in joined_room_ids: | ||||
|                 room_sync = yield self.incremental_sync_with_gap_for_room( | ||||
|                     room_id, sync_config, since_token, now_token, | ||||
|                     ephemeral_by_room, tags_by_room | ||||
|                     ephemeral_by_room, tags_by_room, account_data_by_room | ||||
|                 ) | ||||
|                 if room_sync: | ||||
|                     joined.append(room_sync) | ||||
| 
 | ||||
|         for leave_event in leave_events: | ||||
|             room_sync = yield self.incremental_sync_for_archived_room( | ||||
|                 sync_config, leave_event, since_token, tags_by_room | ||||
|                 sync_config, leave_event, since_token, tags_by_room, | ||||
|                 account_data_by_room | ||||
|             ) | ||||
|             archived.append(room_sync) | ||||
| 
 | ||||
| @ -510,6 +549,7 @@ class SyncHandler(BaseHandler): | ||||
| 
 | ||||
|         defer.returnValue(SyncResult( | ||||
|             presence=presence, | ||||
|             account_data=self.account_data_for_user(account_data), | ||||
|             joined=joined, | ||||
|             invited=invited, | ||||
|             archived=archived, | ||||
| @ -566,7 +606,8 @@ class SyncHandler(BaseHandler): | ||||
|     @defer.inlineCallbacks | ||||
|     def incremental_sync_with_gap_for_room(self, room_id, sync_config, | ||||
|                                            since_token, now_token, | ||||
|                                            ephemeral_by_room, tags_by_room): | ||||
|                                            ephemeral_by_room, tags_by_room, | ||||
|                                            account_data_by_room): | ||||
|         """ Get the incremental delta needed to bring the client up to date for | ||||
|         the room. Gives the client the most recent events and the changes to | ||||
|         state. | ||||
| @ -606,7 +647,7 @@ class SyncHandler(BaseHandler): | ||||
|             state=state, | ||||
|             ephemeral=ephemeral_by_room.get(room_id, []), | ||||
|             account_data=self.account_data_for_room( | ||||
|                 room_id, tags_by_room | ||||
|                 room_id, tags_by_room, account_data_by_room | ||||
|             ), | ||||
|         ) | ||||
| 
 | ||||
| @ -616,7 +657,8 @@ class SyncHandler(BaseHandler): | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def incremental_sync_for_archived_room(self, sync_config, leave_event, | ||||
|                                            since_token, tags_by_room): | ||||
|                                            since_token, tags_by_room, | ||||
|                                            account_data_by_room): | ||||
|         """ Get the incremental delta needed to bring the client up to date for | ||||
|         the archived room. | ||||
|         Returns: | ||||
| @ -654,7 +696,7 @@ class SyncHandler(BaseHandler): | ||||
|             timeline=batch, | ||||
|             state=state_events_delta, | ||||
|             account_data=self.account_data_for_room( | ||||
|                 leave_event.room_id, tags_by_room | ||||
|                 leave_event.room_id, tags_by_room, account_data_by_room | ||||
|             ), | ||||
|         ) | ||||
| 
 | ||||
|  | ||||
| @ -23,6 +23,7 @@ from . import ( | ||||
|     keys, | ||||
|     tokenrefresh, | ||||
|     tags, | ||||
|     account_data, | ||||
| ) | ||||
| 
 | ||||
| from synapse.http.server import JsonResource | ||||
| @ -46,3 +47,4 @@ class ClientV2AlphaRestResource(JsonResource): | ||||
|         keys.register_servlets(hs, client_resource) | ||||
|         tokenrefresh.register_servlets(hs, client_resource) | ||||
|         tags.register_servlets(hs, client_resource) | ||||
|         account_data.register_servlets(hs, client_resource) | ||||
|  | ||||
							
								
								
									
										111
									
								
								synapse/rest/client/v2_alpha/account_data.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										111
									
								
								synapse/rest/client/v2_alpha/account_data.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,111 @@ | ||||
| # -*- coding: utf-8 -*- | ||||
| # Copyright 2015 OpenMarket Ltd | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| 
 | ||||
| from ._base import client_v2_patterns | ||||
| 
 | ||||
| from synapse.http.servlet import RestServlet | ||||
| from synapse.api.errors import AuthError, SynapseError | ||||
| 
 | ||||
| from twisted.internet import defer | ||||
| 
 | ||||
| import logging | ||||
| 
 | ||||
| import simplejson as json | ||||
| 
 | ||||
| logger = logging.getLogger(__name__) | ||||
| 
 | ||||
| 
 | ||||
| class AccountDataServlet(RestServlet): | ||||
|     """ | ||||
|     PUT /user/{user_id}/account_data/{account_dataType} HTTP/1.1 | ||||
|     """ | ||||
|     PATTERNS = client_v2_patterns( | ||||
|         "/user/(?P<user_id>[^/]*)/account_data/(?P<account_data_type>[^/]*)" | ||||
|     ) | ||||
| 
 | ||||
|     def __init__(self, hs): | ||||
|         super(AccountDataServlet, self).__init__() | ||||
|         self.auth = hs.get_auth() | ||||
|         self.store = hs.get_datastore() | ||||
|         self.notifier = hs.get_notifier() | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def on_PUT(self, request, user_id, account_data_type): | ||||
|         auth_user, _, _ = yield self.auth.get_user_by_req(request) | ||||
|         if user_id != auth_user.to_string(): | ||||
|             raise AuthError(403, "Cannot add account data for other users.") | ||||
| 
 | ||||
|         try: | ||||
|             content_bytes = request.content.read() | ||||
|             body = json.loads(content_bytes) | ||||
|         except: | ||||
|             raise SynapseError(400, "Invalid JSON") | ||||
| 
 | ||||
|         max_id = yield self.store.add_account_data_for_user( | ||||
|             user_id, account_data_type, body | ||||
|         ) | ||||
| 
 | ||||
|         yield self.notifier.on_new_event( | ||||
|             "account_data_key", max_id, users=[user_id] | ||||
|         ) | ||||
| 
 | ||||
|         defer.returnValue((200, {})) | ||||
| 
 | ||||
| 
 | ||||
| class RoomAccountDataServlet(RestServlet): | ||||
|     """ | ||||
|     PUT /user/{user_id}/rooms/{room_id}/account_data/{account_dataType} HTTP/1.1 | ||||
|     """ | ||||
|     PATTERNS = client_v2_patterns( | ||||
|         "/user/(?P<user_id>[^/]*)" | ||||
|         "/rooms/(?P<room_id>[^/]*)" | ||||
|         "/account_data/(?P<account_data_type>[^/]*)" | ||||
|     ) | ||||
| 
 | ||||
|     def __init__(self, hs): | ||||
|         super(RoomAccountDataServlet, self).__init__() | ||||
|         self.auth = hs.get_auth() | ||||
|         self.store = hs.get_datastore() | ||||
|         self.notifier = hs.get_notifier() | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def on_PUT(self, request, user_id, room_id, account_data_type): | ||||
|         auth_user, _, _ = yield self.auth.get_user_by_req(request) | ||||
|         if user_id != auth_user.to_string(): | ||||
|             raise AuthError(403, "Cannot add account data for other users.") | ||||
| 
 | ||||
|         try: | ||||
|             content_bytes = request.content.read() | ||||
|             body = json.loads(content_bytes) | ||||
|         except: | ||||
|             raise SynapseError(400, "Invalid JSON") | ||||
| 
 | ||||
|         if not isinstance(body, dict): | ||||
|             raise ValueError("Expected a JSON object") | ||||
| 
 | ||||
|         max_id = yield self.store.add_account_data_to_room( | ||||
|             user_id, room_id, account_data_type, body | ||||
|         ) | ||||
| 
 | ||||
|         yield self.notifier.on_new_event( | ||||
|             "account_data_key", max_id, users=[user_id] | ||||
|         ) | ||||
| 
 | ||||
|         defer.returnValue((200, {})) | ||||
| 
 | ||||
| 
 | ||||
| def register_servlets(hs, http_server): | ||||
|     AccountDataServlet(hs).register(http_server) | ||||
|     RoomAccountDataServlet(hs).register(http_server) | ||||
| @ -144,6 +144,9 @@ class SyncRestServlet(RestServlet): | ||||
|         ) | ||||
| 
 | ||||
|         response_content = { | ||||
|             "account_data": self.encode_account_data( | ||||
|                 sync_result.account_data, filter, time_now | ||||
|             ), | ||||
|             "presence": self.encode_presence( | ||||
|                 sync_result.presence, filter, time_now | ||||
|             ), | ||||
| @ -165,6 +168,9 @@ class SyncRestServlet(RestServlet): | ||||
|             formatted.append(event) | ||||
|         return {"events": filter.filter_presence(formatted)} | ||||
| 
 | ||||
|     def encode_account_data(self, events, filter, time_now): | ||||
|         return {"events": filter.filter_account_data(events)} | ||||
| 
 | ||||
|     def encode_joined(self, rooms, filter, time_now, token_id): | ||||
|         """ | ||||
|         Encode the joined rooms in a sync result | ||||
|  | ||||
| @ -42,6 +42,7 @@ from .end_to_end_keys import EndToEndKeyStore | ||||
| from .receipts import ReceiptsStore | ||||
| from .search import SearchStore | ||||
| from .tags import TagsStore | ||||
| from .account_data import AccountDataStore | ||||
| 
 | ||||
| 
 | ||||
| import logging | ||||
| @ -73,6 +74,7 @@ class DataStore(RoomMemberStore, RoomStore, | ||||
|                 EndToEndKeyStore, | ||||
|                 SearchStore, | ||||
|                 TagsStore, | ||||
|                 AccountDataStore, | ||||
|                 ): | ||||
| 
 | ||||
|     def __init__(self, hs): | ||||
|  | ||||
							
								
								
									
										211
									
								
								synapse/storage/account_data.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										211
									
								
								synapse/storage/account_data.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,211 @@ | ||||
| # -*- coding: utf-8 -*- | ||||
| # Copyright 2014, 2015 OpenMarket Ltd | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| 
 | ||||
| from ._base import SQLBaseStore | ||||
| from twisted.internet import defer | ||||
| 
 | ||||
| import ujson as json | ||||
| import logging | ||||
| 
 | ||||
| logger = logging.getLogger(__name__) | ||||
| 
 | ||||
| 
 | ||||
| class AccountDataStore(SQLBaseStore): | ||||
| 
 | ||||
|     def get_account_data_for_user(self, user_id): | ||||
|         """Get all the client account_data for a user. | ||||
| 
 | ||||
|         Args: | ||||
|             user_id(str): The user to get the account_data for. | ||||
|         Returns: | ||||
|             A deferred pair of a dict of global account_data and a dict | ||||
|             mapping from room_id string to per room account_data dicts. | ||||
|         """ | ||||
| 
 | ||||
|         def get_account_data_for_user_txn(txn): | ||||
|             rows = self._simple_select_list_txn( | ||||
|                 txn, "account_data", {"user_id": user_id}, | ||||
|                 ["account_data_type", "content"] | ||||
|             ) | ||||
| 
 | ||||
|             global_account_data = { | ||||
|                 row["account_data_type"]: json.loads(row["content"]) for row in rows | ||||
|             } | ||||
| 
 | ||||
|             rows = self._simple_select_list_txn( | ||||
|                 txn, "room_account_data", {"user_id": user_id}, | ||||
|                 ["room_id", "account_data_type", "content"] | ||||
|             ) | ||||
| 
 | ||||
|             by_room = {} | ||||
|             for row in rows: | ||||
|                 room_data = by_room.setdefault(row["room_id"], {}) | ||||
|                 room_data[row["account_data_type"]] = json.loads(row["content"]) | ||||
| 
 | ||||
|             return (global_account_data, by_room) | ||||
| 
 | ||||
|         return self.runInteraction( | ||||
|             "get_account_data_for_user", get_account_data_for_user_txn | ||||
|         ) | ||||
| 
 | ||||
|     def get_account_data_for_room(self, user_id, room_id): | ||||
|         """Get all the client account_data for a user for a room. | ||||
| 
 | ||||
|         Args: | ||||
|             user_id(str): The user to get the account_data for. | ||||
|             room_id(str): The room to get the account_data for. | ||||
|         Returns: | ||||
|             A deferred dict of the room account_data | ||||
|         """ | ||||
|         def get_account_data_for_room_txn(txn): | ||||
|             rows = self._simple_select_list_txn( | ||||
|                 txn, "room_account_data", {"user_id": user_id, "room_id": room_id}, | ||||
|                 ["account_data_type", "content"] | ||||
|             ) | ||||
| 
 | ||||
|             return { | ||||
|                 row["account_data_type"]: json.loads(row["content"]) for row in rows | ||||
|             } | ||||
| 
 | ||||
|         return self.runInteraction( | ||||
|             "get_account_data_for_room", get_account_data_for_room_txn | ||||
|         ) | ||||
| 
 | ||||
|     def get_updated_account_data_for_user(self, user_id, stream_id): | ||||
|         """Get all the client account_data for a that's changed. | ||||
| 
 | ||||
|         Args: | ||||
|             user_id(str): The user to get the account_data for. | ||||
|             stream_id(int): The point in the stream since which to get updates | ||||
|         Returns: | ||||
|             A deferred pair of a dict of global account_data and a dict | ||||
|             mapping from room_id string to per room account_data dicts. | ||||
|         """ | ||||
| 
 | ||||
|         def get_updated_account_data_for_user_txn(txn): | ||||
|             sql = ( | ||||
|                 "SELECT account_data_type, content FROM account_data" | ||||
|                 " WHERE user_id = ? AND stream_id > ?" | ||||
|             ) | ||||
| 
 | ||||
|             txn.execute(sql, (user_id, stream_id)) | ||||
| 
 | ||||
|             global_account_data = { | ||||
|                 row[0]: json.loads(row[1]) for row in txn.fetchall() | ||||
|             } | ||||
| 
 | ||||
|             sql = ( | ||||
|                 "SELECT room_id, account_data_type, content FROM room_account_data" | ||||
|                 " WHERE user_id = ? AND stream_id > ?" | ||||
|             ) | ||||
| 
 | ||||
|             txn.execute(sql, (user_id, stream_id)) | ||||
| 
 | ||||
|             account_data_by_room = {} | ||||
|             for row in txn.fetchall(): | ||||
|                 room_account_data = account_data_by_room.setdefault(row[0], {}) | ||||
|                 room_account_data[row[1]] = json.loads(row[2]) | ||||
| 
 | ||||
|             return (global_account_data, account_data_by_room) | ||||
| 
 | ||||
|         return self.runInteraction( | ||||
|             "get_updated_account_data_for_user", get_updated_account_data_for_user_txn | ||||
|         ) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def add_account_data_to_room(self, user_id, room_id, account_data_type, content): | ||||
|         """Add some account_data to a room for a user. | ||||
|         Args: | ||||
|             user_id(str): The user to add a tag for. | ||||
|             room_id(str): The room to add a tag for. | ||||
|             account_data_type(str): The type of account_data to add. | ||||
|             content(dict): A json object to associate with the tag. | ||||
|         Returns: | ||||
|             A deferred that completes once the account_data has been added. | ||||
|         """ | ||||
|         content_json = json.dumps(content) | ||||
| 
 | ||||
|         def add_account_data_txn(txn, next_id): | ||||
|             self._simple_upsert_txn( | ||||
|                 txn, | ||||
|                 table="room_account_data", | ||||
|                 keyvalues={ | ||||
|                     "user_id": user_id, | ||||
|                     "room_id": room_id, | ||||
|                     "account_data_type": account_data_type, | ||||
|                 }, | ||||
|                 values={ | ||||
|                     "stream_id": next_id, | ||||
|                     "content": content_json, | ||||
|                 } | ||||
|             ) | ||||
|             self._update_max_stream_id(txn, next_id) | ||||
| 
 | ||||
|         with (yield self._account_data_id_gen.get_next(self)) as next_id: | ||||
|             yield self.runInteraction( | ||||
|                 "add_room_account_data", add_account_data_txn, next_id | ||||
|             ) | ||||
| 
 | ||||
|         result = yield self._account_data_id_gen.get_max_token(self) | ||||
|         defer.returnValue(result) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def add_account_data_for_user(self, user_id, account_data_type, content): | ||||
|         """Add some account_data to a room for a user. | ||||
|         Args: | ||||
|             user_id(str): The user to add a tag for. | ||||
|             account_data_type(str): The type of account_data to add. | ||||
|             content(dict): A json object to associate with the tag. | ||||
|         Returns: | ||||
|             A deferred that completes once the account_data has been added. | ||||
|         """ | ||||
|         content_json = json.dumps(content) | ||||
| 
 | ||||
|         def add_account_data_txn(txn, next_id): | ||||
|             self._simple_upsert_txn( | ||||
|                 txn, | ||||
|                 table="account_data", | ||||
|                 keyvalues={ | ||||
|                     "user_id": user_id, | ||||
|                     "account_data_type": account_data_type, | ||||
|                 }, | ||||
|                 values={ | ||||
|                     "stream_id": next_id, | ||||
|                     "content": content_json, | ||||
|                 } | ||||
|             ) | ||||
|             self._update_max_stream_id(txn, next_id) | ||||
| 
 | ||||
|         with (yield self._account_data_id_gen.get_next(self)) as next_id: | ||||
|             yield self.runInteraction( | ||||
|                 "add_user_account_data", add_account_data_txn, next_id | ||||
|             ) | ||||
| 
 | ||||
|         result = yield self._account_data_id_gen.get_max_token(self) | ||||
|         defer.returnValue(result) | ||||
| 
 | ||||
|     def _update_max_stream_id(self, txn, next_id): | ||||
|         """Update the max stream_id | ||||
| 
 | ||||
|         Args: | ||||
|             txn: The database cursor | ||||
|             next_id(int): The the revision to advance to. | ||||
|         """ | ||||
|         update_max_id_sql = ( | ||||
|             "UPDATE account_data_max_stream_id" | ||||
|             " SET stream_id = ?" | ||||
|             " WHERE stream_id < ?" | ||||
|         ) | ||||
|         txn.execute(update_max_id_sql, (next_id, next_id)) | ||||
| @ -15,3 +15,26 @@ | ||||
| 
 | ||||
| 
 | ||||
| ALTER TABLE private_user_data_max_stream_id RENAME TO account_data_max_stream_id; | ||||
| 
 | ||||
| 
 | ||||
| CREATE TABLE IF NOT EXISTS account_data( | ||||
|     user_id TEXT NOT NULL, | ||||
|     account_data_type TEXT NOT NULL, -- The type of the account_data. | ||||
|     stream_id BIGINT NOT NULL, -- The version of the account_data. | ||||
|     content TEXT NOT NULL,  -- The JSON content of the account_data | ||||
|     CONSTRAINT account_data_uniqueness UNIQUE (user_id, account_data_type) | ||||
| ); | ||||
| 
 | ||||
| 
 | ||||
| CREATE TABLE IF NOT EXISTS room_account_data( | ||||
|     user_id TEXT NOT NULL, | ||||
|     room_id TEXT NOT NULL, | ||||
|     account_data_type TEXT NOT NULL, -- The type of the account_data. | ||||
|     stream_id BIGINT NOT NULL, -- The version of the account_data. | ||||
|     content TEXT NOT NULL,  -- The JSON content of the account_data | ||||
|     CONSTRAINT room_account_data_uniqueness UNIQUE (user_id, room_id, account_data_type) | ||||
| ); | ||||
| 
 | ||||
| 
 | ||||
| CREATE INDEX account_data_stream_id on account_data(user_id, stream_id); | ||||
| CREATE INDEX room_account_data_stream_id on room_account_data(user_id, stream_id); | ||||
|  | ||||
| @ -48,8 +48,8 @@ class TagsStore(SQLBaseStore): | ||||
|         Args: | ||||
|             user_id(str): The user to get the tags for. | ||||
|         Returns: | ||||
|             A deferred dict mapping from room_id strings to lists of tag | ||||
|             strings. | ||||
|             A deferred dict mapping from room_id strings to dicts mapping from | ||||
|             tag strings to tag content. | ||||
|         """ | ||||
| 
 | ||||
|         deferred = self._simple_select_list( | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user