mirror of
				https://github.com/matrix-org/synapse.git
				synced 2025-10-31 00:01:33 +01:00 
			
		
		
		
	Merge branch 'release-v0.20.0' of github.com:matrix-org/synapse
This commit is contained in:
		
						commit
						4902db1fc9
					
				
							
								
								
									
										53
									
								
								CHANGES.rst
									
									
									
									
									
								
							
							
						
						
									
										53
									
								
								CHANGES.rst
									
									
									
									
									
								
							| @ -1,3 +1,56 @@ | |||||||
|  | Changes in synapse v0.20.0 (2017-04-11) | ||||||
|  | ======================================= | ||||||
|  | 
 | ||||||
|  | Bug fixes: | ||||||
|  | 
 | ||||||
|  | * Fix joining rooms over federation where not all servers in the room saw the | ||||||
|  |   new server had joined (PR #2094) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | Changes in synapse v0.20.0-rc1 (2017-03-30) | ||||||
|  | =========================================== | ||||||
|  | 
 | ||||||
|  | Features: | ||||||
|  | 
 | ||||||
|  | * Add delete_devices API (PR #1993) | ||||||
|  | * Add phone number registration/login support (PR #1994, #2055) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | Changes: | ||||||
|  | 
 | ||||||
|  | * Use JSONSchema for validation of filters. Thanks @pik! (PR #1783) | ||||||
|  | * Reread log config on SIGHUP (PR #1982) | ||||||
|  | * Speed up public room list (PR #1989) | ||||||
|  | * Add helpful texts to logger config options (PR #1990) | ||||||
|  | * Minor ``/sync`` performance improvements. (PR #2002, #2013, #2022) | ||||||
|  | * Add some debug to help diagnose weird federation issue (PR #2035) | ||||||
|  | * Correctly limit retries for all federation requests (PR #2050, #2061) | ||||||
|  | * Don't lock table when persisting new one time keys (PR #2053) | ||||||
|  | * Reduce some CPU work on DB threads (PR #2054) | ||||||
|  | * Cache hosts in room (PR #2060) | ||||||
|  | * Batch sending of device list pokes (PR #2063) | ||||||
|  | * Speed up persist event path in certain edge cases (PR #2070) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | Bug fixes: | ||||||
|  | 
 | ||||||
|  | * Fix bug where current_state_events renamed to current_state_ids (PR #1849) | ||||||
|  | * Fix routing loop when fetching remote media (PR #1992) | ||||||
|  | * Fix current_state_events table to not lie (PR #1996) | ||||||
|  | * Fix CAS login to handle PartialDownloadError (PR #1997) | ||||||
|  | * Fix assertion to stop transaction queue getting wedged (PR #2010) | ||||||
|  | * Fix presence to fallback to last_active_ts if it beats the last sync time. | ||||||
|  |   Thanks @Half-Shot! (PR #2014) | ||||||
|  | * Fix bug when federation received a PDU while a room join is in progress (PR | ||||||
|  |   #2016) | ||||||
|  | * Fix resetting state on rejected events (PR #2025) | ||||||
|  | * Fix installation issues in readme. Thanks @ricco386 (PR #2037) | ||||||
|  | * Fix caching of remote servers' signature keys (PR #2042) | ||||||
|  | * Fix some leaking log context (PR #2048, #2049, #2057, #2058) | ||||||
|  | * Fix rejection of invites not reaching sync (PR #2056) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| Changes in synapse v0.19.3 (2017-03-20) | Changes in synapse v0.19.3 (2017-03-20) | ||||||
| ======================================= | ======================================= | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -146,6 +146,7 @@ To install the synapse homeserver run:: | |||||||
| 
 | 
 | ||||||
|     virtualenv -p python2.7 ~/.synapse |     virtualenv -p python2.7 ~/.synapse | ||||||
|     source ~/.synapse/bin/activate |     source ~/.synapse/bin/activate | ||||||
|  |     pip install --upgrade pip | ||||||
|     pip install --upgrade setuptools |     pip install --upgrade setuptools | ||||||
|     pip install https://github.com/matrix-org/synapse/tarball/master |     pip install https://github.com/matrix-org/synapse/tarball/master | ||||||
| 
 | 
 | ||||||
| @ -228,6 +229,7 @@ To get started, it is easiest to use the command line to register new users:: | |||||||
|     New user localpart: erikj |     New user localpart: erikj | ||||||
|     Password: |     Password: | ||||||
|     Confirm password: |     Confirm password: | ||||||
|  |     Make admin [no]: | ||||||
|     Success! |     Success! | ||||||
| 
 | 
 | ||||||
| This process uses a setting ``registration_shared_secret`` in | This process uses a setting ``registration_shared_secret`` in | ||||||
| @ -808,7 +810,7 @@ directory of your choice:: | |||||||
| Synapse has a number of external dependencies, that are easiest | Synapse has a number of external dependencies, that are easiest | ||||||
| to install using pip and a virtualenv:: | to install using pip and a virtualenv:: | ||||||
| 
 | 
 | ||||||
|     virtualenv env |     virtualenv -p python2.7 env | ||||||
|     source env/bin/activate |     source env/bin/activate | ||||||
|     python synapse/python_dependencies.py | xargs pip install |     python synapse/python_dependencies.py | xargs pip install | ||||||
|     pip install lxml mock |     pip install lxml mock | ||||||
|  | |||||||
| @ -39,7 +39,9 @@ loggers: | |||||||
|     synapse: |     synapse: | ||||||
|         level: INFO |         level: INFO | ||||||
| 
 | 
 | ||||||
|     synapse.storage: |     synapse.storage.SQL: | ||||||
|  |         # beware: increasing this to DEBUG will make synapse log sensitive | ||||||
|  |         # information such as access tokens. | ||||||
|         level: INFO |         level: INFO | ||||||
| 
 | 
 | ||||||
|     # example of enabling debugging for a component: |     # example of enabling debugging for a component: | ||||||
|  | |||||||
| @ -1,10 +1,446 @@ | |||||||
| What do I do about "Unexpected logging context" debug log-lines everywhere? | Log contexts | ||||||
|  | ============ | ||||||
| 
 | 
 | ||||||
| <Mjark> The logging context lives in thread local storage | .. contents:: | ||||||
| <Mjark> Sometimes it gets out of sync with what it should actually be, usually because something scheduled something to run on the reactor without preserving the logging context.  |  | ||||||
| <Matthew> what is the impact of it getting out of sync? and how and when should we preserve log context? |  | ||||||
| <Mjark> The impact is that some of the CPU and database metrics will be under-reported, and some log lines will be mis-attributed. |  | ||||||
| <Mjark> It should happen auto-magically in all the APIs that do IO or otherwise defer to the reactor. |  | ||||||
| <Erik> Mjark: the other place is if we branch, e.g. using defer.gatherResults |  | ||||||
| 
 | 
 | ||||||
| Unanswered: how and when should we preserve log context? | To help track the processing of individual requests, synapse uses a | ||||||
|  | 'log context' to track which request it is handling at any given moment. This | ||||||
|  | is done via a thread-local variable; a ``logging.Filter`` is then used to fish | ||||||
|  | the information back out of the thread-local variable and add it to each log | ||||||
|  | record. | ||||||
|  | 
 | ||||||
|  | Logcontexts are also used for CPU and database accounting, so that we can track | ||||||
|  | which requests were responsible for high CPU use or database activity. | ||||||
|  | 
 | ||||||
|  | The ``synapse.util.logcontext`` module provides a facilities for managing the | ||||||
|  | current log context (as well as providing the ``LoggingContextFilter`` class). | ||||||
|  | 
 | ||||||
|  | Deferreds make the whole thing complicated, so this document describes how it | ||||||
|  | all works, and how to write code which follows the rules. | ||||||
|  | 
 | ||||||
|  | Logcontexts without Deferreds | ||||||
|  | ----------------------------- | ||||||
|  | 
 | ||||||
|  | In the absence of any Deferred voodoo, things are simple enough. As with any | ||||||
|  | code of this nature, the rule is that our function should leave things as it | ||||||
|  | found them: | ||||||
|  | 
 | ||||||
|  | .. code:: python | ||||||
|  | 
 | ||||||
|  |     from synapse.util import logcontext         # omitted from future snippets | ||||||
|  | 
 | ||||||
|  |     def handle_request(request_id): | ||||||
|  |         request_context = logcontext.LoggingContext() | ||||||
|  | 
 | ||||||
|  |         calling_context = logcontext.LoggingContext.current_context() | ||||||
|  |         logcontext.LoggingContext.set_current_context(request_context) | ||||||
|  |         try: | ||||||
|  |             request_context.request = request_id | ||||||
|  |             do_request_handling() | ||||||
|  |             logger.debug("finished") | ||||||
|  |         finally: | ||||||
|  |             logcontext.LoggingContext.set_current_context(calling_context) | ||||||
|  | 
 | ||||||
|  |     def do_request_handling(): | ||||||
|  |         logger.debug("phew")  # this will be logged against request_id | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | LoggingContext implements the context management methods, so the above can be | ||||||
|  | written much more succinctly as: | ||||||
|  | 
 | ||||||
|  | .. code:: python | ||||||
|  | 
 | ||||||
|  |     def handle_request(request_id): | ||||||
|  |         with logcontext.LoggingContext() as request_context: | ||||||
|  |             request_context.request = request_id | ||||||
|  |             do_request_handling() | ||||||
|  |             logger.debug("finished") | ||||||
|  | 
 | ||||||
|  |     def do_request_handling(): | ||||||
|  |         logger.debug("phew") | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | Using logcontexts with Deferreds | ||||||
|  | -------------------------------- | ||||||
|  | 
 | ||||||
|  | Deferreds — and in particular, ``defer.inlineCallbacks`` — break | ||||||
|  | the linear flow of code so that there is no longer a single entry point where | ||||||
|  | we should set the logcontext and a single exit point where we should remove it. | ||||||
|  | 
 | ||||||
|  | Consider the example above, where ``do_request_handling`` needs to do some | ||||||
|  | blocking operation, and returns a deferred: | ||||||
|  | 
 | ||||||
|  | .. code:: python | ||||||
|  | 
 | ||||||
|  |     @defer.inlineCallbacks | ||||||
|  |     def handle_request(request_id): | ||||||
|  |         with logcontext.LoggingContext() as request_context: | ||||||
|  |             request_context.request = request_id | ||||||
|  |             yield do_request_handling() | ||||||
|  |             logger.debug("finished") | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | In the above flow: | ||||||
|  | 
 | ||||||
|  | * The logcontext is set | ||||||
|  | * ``do_request_handling`` is called, and returns a deferred | ||||||
|  | * ``handle_request`` yields the deferred | ||||||
|  | * The ``inlineCallbacks`` wrapper of ``handle_request`` returns a deferred | ||||||
|  | 
 | ||||||
|  | So we have stopped processing the request (and will probably go on to start | ||||||
|  | processing the next), without clearing the logcontext. | ||||||
|  | 
 | ||||||
|  | To circumvent this problem, synapse code assumes that, wherever you have a | ||||||
|  | deferred, you will want to yield on it. To that end, whereever functions return | ||||||
|  | a deferred, we adopt the following conventions: | ||||||
|  | 
 | ||||||
|  | **Rules for functions returning deferreds:** | ||||||
|  | 
 | ||||||
|  |   * If the deferred is already complete, the function returns with the same | ||||||
|  |     logcontext it started with. | ||||||
|  |   * If the deferred is incomplete, the function clears the logcontext before | ||||||
|  |     returning; when the deferred completes, it restores the logcontext before | ||||||
|  |     running any callbacks. | ||||||
|  | 
 | ||||||
|  | That sounds complicated, but actually it means a lot of code (including the | ||||||
|  | example above) "just works". There are two cases: | ||||||
|  | 
 | ||||||
|  | * If ``do_request_handling`` returns a completed deferred, then the logcontext | ||||||
|  |   will still be in place. In this case, execution will continue immediately | ||||||
|  |   after the ``yield``; the "finished" line will be logged against the right | ||||||
|  |   context, and the ``with`` block restores the original context before we | ||||||
|  |   return to the caller. | ||||||
|  | 
 | ||||||
|  | * If the returned deferred is incomplete, ``do_request_handling`` clears the | ||||||
|  |   logcontext before returning. The logcontext is therefore clear when | ||||||
|  |   ``handle_request`` yields the deferred. At that point, the ``inlineCallbacks`` | ||||||
|  |   wrapper adds a callback to the deferred, and returns another (incomplete) | ||||||
|  |   deferred to the caller, and it is safe to begin processing the next request. | ||||||
|  | 
 | ||||||
|  |   Once ``do_request_handling``'s deferred completes, it will reinstate the | ||||||
|  |   logcontext, before running the callback added by the ``inlineCallbacks`` | ||||||
|  |   wrapper. That callback runs the second half of ``handle_request``, so again | ||||||
|  |   the "finished" line will be logged against the right | ||||||
|  |   context, and the ``with`` block restores the original context. | ||||||
|  | 
 | ||||||
|  | As an aside, it's worth noting that ``handle_request`` follows our rules - | ||||||
|  | though that only matters if the caller has its own logcontext which it cares | ||||||
|  | about. | ||||||
|  | 
 | ||||||
|  | The following sections describe pitfalls and helpful patterns when implementing | ||||||
|  | these rules. | ||||||
|  | 
 | ||||||
|  | Always yield your deferreds | ||||||
|  | --------------------------- | ||||||
|  | 
 | ||||||
|  | Whenever you get a deferred back from a function, you should ``yield`` on it | ||||||
|  | as soon as possible. (Returning it directly to your caller is ok too, if you're | ||||||
|  | not doing ``inlineCallbacks``.) Do not pass go; do not do any logging; do not | ||||||
|  | call any other functions. | ||||||
|  | 
 | ||||||
|  | .. code:: python | ||||||
|  | 
 | ||||||
|  |     @defer.inlineCallbacks | ||||||
|  |     def fun(): | ||||||
|  |         logger.debug("starting") | ||||||
|  |         yield do_some_stuff()       # just like this | ||||||
|  | 
 | ||||||
|  |         d = more_stuff() | ||||||
|  |         result = yield d            # also fine, of course | ||||||
|  | 
 | ||||||
|  |         defer.returnValue(result) | ||||||
|  | 
 | ||||||
|  |     def nonInlineCallbacksFun(): | ||||||
|  |         logger.debug("just a wrapper really") | ||||||
|  |         return do_some_stuff()      # this is ok too - the caller will yield on | ||||||
|  |                                     # it anyway. | ||||||
|  | 
 | ||||||
|  | Provided this pattern is followed all the way back up to the callchain to where | ||||||
|  | the logcontext was set, this will make things work out ok: provided | ||||||
|  | ``do_some_stuff`` and ``more_stuff`` follow the rules above, then so will | ||||||
|  | ``fun`` (as wrapped by ``inlineCallbacks``) and ``nonInlineCallbacksFun``. | ||||||
|  | 
 | ||||||
|  | It's all too easy to forget to ``yield``: for instance if we forgot that | ||||||
|  | ``do_some_stuff`` returned a deferred, we might plough on regardless. This | ||||||
|  | leads to a mess; it will probably work itself out eventually, but not before | ||||||
|  | a load of stuff has been logged against the wrong content. (Normally, other | ||||||
|  | things will break, more obviously, if you forget to ``yield``, so this tends | ||||||
|  | not to be a major problem in practice.) | ||||||
|  | 
 | ||||||
|  | Of course sometimes you need to do something a bit fancier with your Deferreds | ||||||
|  | - not all code follows the linear A-then-B-then-C pattern. Notes on | ||||||
|  | implementing more complex patterns are in later sections. | ||||||
|  | 
 | ||||||
|  | Where you create a new Deferred, make it follow the rules | ||||||
|  | --------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | Most of the time, a Deferred comes from another synapse function. Sometimes, | ||||||
|  | though, we need to make up a new Deferred, or we get a Deferred back from | ||||||
|  | external code. We need to make it follow our rules. | ||||||
|  | 
 | ||||||
|  | The easy way to do it is with a combination of ``defer.inlineCallbacks``, and | ||||||
|  | ``logcontext.PreserveLoggingContext``. Suppose we want to implement ``sleep``, | ||||||
|  | which returns a deferred which will run its callbacks after a given number of | ||||||
|  | seconds. That might look like: | ||||||
|  | 
 | ||||||
|  | .. code:: python | ||||||
|  | 
 | ||||||
|  |     # not a logcontext-rules-compliant function | ||||||
|  |     def get_sleep_deferred(seconds): | ||||||
|  |         d = defer.Deferred() | ||||||
|  |         reactor.callLater(seconds, d.callback, None) | ||||||
|  |         return d | ||||||
|  | 
 | ||||||
|  | That doesn't follow the rules, but we can fix it by wrapping it with | ||||||
|  | ``PreserveLoggingContext`` and ``yield`` ing on it: | ||||||
|  | 
 | ||||||
|  | .. code:: python | ||||||
|  | 
 | ||||||
|  |     @defer.inlineCallbacks | ||||||
|  |     def sleep(seconds): | ||||||
|  |         with PreserveLoggingContext(): | ||||||
|  |             yield get_sleep_deferred(seconds) | ||||||
|  | 
 | ||||||
|  | This technique works equally for external functions which return deferreds, | ||||||
|  | or deferreds we have made ourselves. | ||||||
|  | 
 | ||||||
|  | You can also use ``logcontext.make_deferred_yieldable``, which just does the | ||||||
|  | boilerplate for you, so the above could be written: | ||||||
|  | 
 | ||||||
|  | .. code:: python | ||||||
|  | 
 | ||||||
|  |     def sleep(seconds): | ||||||
|  |         return logcontext.make_deferred_yieldable(get_sleep_deferred(seconds)) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | Fire-and-forget | ||||||
|  | --------------- | ||||||
|  | 
 | ||||||
|  | Sometimes you want to fire off a chain of execution, but not wait for its | ||||||
|  | result. That might look a bit like this: | ||||||
|  | 
 | ||||||
|  | .. code:: python | ||||||
|  | 
 | ||||||
|  |     @defer.inlineCallbacks | ||||||
|  |     def do_request_handling(): | ||||||
|  |         yield foreground_operation() | ||||||
|  | 
 | ||||||
|  |         # *don't* do this | ||||||
|  |         background_operation() | ||||||
|  | 
 | ||||||
|  |         logger.debug("Request handling complete") | ||||||
|  | 
 | ||||||
|  |     @defer.inlineCallbacks | ||||||
|  |     def background_operation(): | ||||||
|  |         yield first_background_step() | ||||||
|  |         logger.debug("Completed first step") | ||||||
|  |         yield second_background_step() | ||||||
|  |         logger.debug("Completed second step") | ||||||
|  | 
 | ||||||
|  | The above code does a couple of steps in the background after | ||||||
|  | ``do_request_handling`` has finished. The log lines are still logged against | ||||||
|  | the ``request_context`` logcontext, which may or may not be desirable. There | ||||||
|  | are two big problems with the above, however. The first problem is that, if | ||||||
|  | ``background_operation`` returns an incomplete Deferred, it will expect its | ||||||
|  | caller to ``yield`` immediately, so will have cleared the logcontext. In this | ||||||
|  | example, that means that 'Request handling complete' will be logged without any | ||||||
|  | context. | ||||||
|  | 
 | ||||||
|  | The second problem, which is potentially even worse, is that when the Deferred | ||||||
|  | returned by ``background_operation`` completes, it will restore the original | ||||||
|  | logcontext. There is nothing waiting on that Deferred, so the logcontext will | ||||||
|  | leak into the reactor and possibly get attached to some arbitrary future | ||||||
|  | operation. | ||||||
|  | 
 | ||||||
|  | There are two potential solutions to this. | ||||||
|  | 
 | ||||||
|  | One option is to surround the call to ``background_operation`` with a | ||||||
|  | ``PreserveLoggingContext`` call. That will reset the logcontext before | ||||||
|  | starting ``background_operation`` (so the context restored when the deferred | ||||||
|  | completes will be the empty logcontext), and will restore the current | ||||||
|  | logcontext before continuing the foreground process: | ||||||
|  | 
 | ||||||
|  | .. code:: python | ||||||
|  | 
 | ||||||
|  |     @defer.inlineCallbacks | ||||||
|  |     def do_request_handling(): | ||||||
|  |         yield foreground_operation() | ||||||
|  | 
 | ||||||
|  |         # start background_operation off in the empty logcontext, to | ||||||
|  |         # avoid leaking the current context into the reactor. | ||||||
|  |         with PreserveLoggingContext(): | ||||||
|  |             background_operation() | ||||||
|  | 
 | ||||||
|  |         # this will now be logged against the request context | ||||||
|  |         logger.debug("Request handling complete") | ||||||
|  | 
 | ||||||
|  | Obviously that option means that the operations done in | ||||||
|  | ``background_operation`` would be not be logged against a logcontext (though | ||||||
|  | that might be fixed by setting a different logcontext via a ``with | ||||||
|  | LoggingContext(...)`` in ``background_operation``). | ||||||
|  | 
 | ||||||
|  | The second option is to use ``logcontext.preserve_fn``, which wraps a function | ||||||
|  | so that it doesn't reset the logcontext even when it returns an incomplete | ||||||
|  | deferred, and adds a callback to the returned deferred to reset the | ||||||
|  | logcontext. In other words, it turns a function that follows the Synapse rules | ||||||
|  | about logcontexts and Deferreds into one which behaves more like an external | ||||||
|  | function — the opposite operation to that described in the previous section. | ||||||
|  | It can be used like this: | ||||||
|  | 
 | ||||||
|  | .. code:: python | ||||||
|  | 
 | ||||||
|  |     @defer.inlineCallbacks | ||||||
|  |     def do_request_handling(): | ||||||
|  |         yield foreground_operation() | ||||||
|  | 
 | ||||||
|  |         logcontext.preserve_fn(background_operation)() | ||||||
|  | 
 | ||||||
|  |         # this will now be logged against the request context | ||||||
|  |         logger.debug("Request handling complete") | ||||||
|  | 
 | ||||||
|  | XXX: I think ``preserve_context_over_fn`` is supposed to do the first option, | ||||||
|  | but the fact that it does ``preserve_context_over_deferred`` on its results | ||||||
|  | means that its use is fraught with difficulty. | ||||||
|  | 
 | ||||||
|  | Passing synapse deferreds into third-party functions | ||||||
|  | ---------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | A typical example of this is where we want to collect together two or more | ||||||
|  | deferred via ``defer.gatherResults``: | ||||||
|  | 
 | ||||||
|  | .. code:: python | ||||||
|  | 
 | ||||||
|  |     d1 = operation1() | ||||||
|  |     d2 = operation2() | ||||||
|  |     d3 = defer.gatherResults([d1, d2]) | ||||||
|  | 
 | ||||||
|  | This is really a variation of the fire-and-forget problem above, in that we are | ||||||
|  | firing off ``d1`` and ``d2`` without yielding on them. The difference | ||||||
|  | is that we now have third-party code attached to their callbacks. Anyway either | ||||||
|  | technique given in the `Fire-and-forget`_ section will work. | ||||||
|  | 
 | ||||||
|  | Of course, the new Deferred returned by ``gatherResults`` needs to be wrapped | ||||||
|  | in order to make it follow the logcontext rules before we can yield it, as | ||||||
|  | described in `Where you create a new Deferred, make it follow the rules`_. | ||||||
|  | 
 | ||||||
|  | So, option one: reset the logcontext before starting the operations to be | ||||||
|  | gathered: | ||||||
|  | 
 | ||||||
|  | .. code:: python | ||||||
|  | 
 | ||||||
|  |     @defer.inlineCallbacks | ||||||
|  |     def do_request_handling(): | ||||||
|  |         with PreserveLoggingContext(): | ||||||
|  |             d1 = operation1() | ||||||
|  |             d2 = operation2() | ||||||
|  |             result = yield defer.gatherResults([d1, d2]) | ||||||
|  | 
 | ||||||
|  | In this case particularly, though, option two, of using | ||||||
|  | ``logcontext.preserve_fn`` almost certainly makes more sense, so that | ||||||
|  | ``operation1`` and ``operation2`` are both logged against the original | ||||||
|  | logcontext. This looks like: | ||||||
|  | 
 | ||||||
|  | .. code:: python | ||||||
|  | 
 | ||||||
|  |     @defer.inlineCallbacks | ||||||
|  |     def do_request_handling(): | ||||||
|  |         d1 = logcontext.preserve_fn(operation1)() | ||||||
|  |         d2 = logcontext.preserve_fn(operation2)() | ||||||
|  | 
 | ||||||
|  |         with PreserveLoggingContext(): | ||||||
|  |             result = yield defer.gatherResults([d1, d2]) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | Was all this really necessary? | ||||||
|  | ------------------------------ | ||||||
|  | 
 | ||||||
|  | The conventions used work fine for a linear flow where everything happens in | ||||||
|  | series via ``defer.inlineCallbacks`` and ``yield``, but are certainly tricky to | ||||||
|  | follow for any more exotic flows. It's hard not to wonder if we could have done | ||||||
|  | something else. | ||||||
|  | 
 | ||||||
|  | We're not going to rewrite Synapse now, so the following is entirely of | ||||||
|  | academic interest, but I'd like to record some thoughts on an alternative | ||||||
|  | approach. | ||||||
|  | 
 | ||||||
|  | I briefly prototyped some code following an alternative set of rules. I think | ||||||
|  | it would work, but I certainly didn't get as far as thinking how it would | ||||||
|  | interact with concepts as complicated as the cache descriptors. | ||||||
|  | 
 | ||||||
|  | My alternative rules were: | ||||||
|  | 
 | ||||||
|  | * functions always preserve the logcontext of their caller, whether or not they | ||||||
|  |   are returning a Deferred. | ||||||
|  | 
 | ||||||
|  | * Deferreds returned by synapse functions run their callbacks in the same | ||||||
|  |   context as the function was orignally called in. | ||||||
|  | 
 | ||||||
|  | The main point of this scheme is that everywhere that sets the logcontext is | ||||||
|  | responsible for clearing it before returning control to the reactor. | ||||||
|  | 
 | ||||||
|  | So, for example, if you were the function which started a ``with | ||||||
|  | LoggingContext`` block, you wouldn't ``yield`` within it — instead you'd start | ||||||
|  | off the background process, and then leave the ``with`` block to wait for it: | ||||||
|  | 
 | ||||||
|  | .. code:: python | ||||||
|  | 
 | ||||||
|  |     def handle_request(request_id): | ||||||
|  |         with logcontext.LoggingContext() as request_context: | ||||||
|  |             request_context.request = request_id | ||||||
|  |             d = do_request_handling() | ||||||
|  | 
 | ||||||
|  |         def cb(r): | ||||||
|  |             logger.debug("finished") | ||||||
|  | 
 | ||||||
|  |         d.addCallback(cb) | ||||||
|  |         return d | ||||||
|  | 
 | ||||||
|  | (in general, mixing ``with LoggingContext`` blocks and | ||||||
|  | ``defer.inlineCallbacks`` in the same function leads to slighly | ||||||
|  | counter-intuitive code, under this scheme). | ||||||
|  | 
 | ||||||
|  | Because we leave the original ``with`` block as soon as the Deferred is | ||||||
|  | returned (as opposed to waiting for it to be resolved, as we do today), the | ||||||
|  | logcontext is cleared before control passes back to the reactor; so if there is | ||||||
|  | some code within ``do_request_handling`` which needs to wait for a Deferred to | ||||||
|  | complete, there is no need for it to worry about clearing the logcontext before | ||||||
|  | doing so: | ||||||
|  | 
 | ||||||
|  | .. code:: python | ||||||
|  | 
 | ||||||
|  |     def handle_request(): | ||||||
|  |         r = do_some_stuff() | ||||||
|  |         r.addCallback(do_some_more_stuff) | ||||||
|  |         return r | ||||||
|  | 
 | ||||||
|  | — and provided ``do_some_stuff`` follows the rules of returning a Deferred which | ||||||
|  | runs its callbacks in the original logcontext, all is happy. | ||||||
|  | 
 | ||||||
|  | The business of a Deferred which runs its callbacks in the original logcontext | ||||||
|  | isn't hard to achieve — we have it today, in the shape of | ||||||
|  | ``logcontext._PreservingContextDeferred``: | ||||||
|  | 
 | ||||||
|  | .. code:: python | ||||||
|  | 
 | ||||||
|  |     def do_some_stuff(): | ||||||
|  |         deferred = do_some_io() | ||||||
|  |         pcd = _PreservingContextDeferred(LoggingContext.current_context()) | ||||||
|  |         deferred.chainDeferred(pcd) | ||||||
|  |         return pcd | ||||||
|  | 
 | ||||||
|  | It turns out that, thanks to the way that Deferreds chain together, we | ||||||
|  | automatically get the property of a context-preserving deferred with | ||||||
|  | ``defer.inlineCallbacks``, provided the final Defered the function ``yields`` | ||||||
|  | on has that property. So we can just write: | ||||||
|  | 
 | ||||||
|  | .. code:: python | ||||||
|  | 
 | ||||||
|  |     @defer.inlineCallbacks | ||||||
|  |     def handle_request(): | ||||||
|  |         yield do_some_stuff() | ||||||
|  |         yield do_some_more_stuff() | ||||||
|  | 
 | ||||||
|  | To conclude: I think this scheme would have worked equally well, with less | ||||||
|  | danger of messing it up, and probably made some more esoteric code easier to | ||||||
|  | write. But again — changing the conventions of the entire Synapse codebase is | ||||||
|  | not a sensible option for the marginal improvement offered. | ||||||
|  | |||||||
| @ -16,4 +16,4 @@ | |||||||
| """ This is a reference implementation of a Matrix home server. | """ This is a reference implementation of a Matrix home server. | ||||||
| """ | """ | ||||||
| 
 | 
 | ||||||
| __version__ = "0.19.3" | __version__ = "0.20.0" | ||||||
|  | |||||||
| @ -23,7 +23,7 @@ from synapse import event_auth | |||||||
| from synapse.api.constants import EventTypes, Membership, JoinRules | from synapse.api.constants import EventTypes, Membership, JoinRules | ||||||
| from synapse.api.errors import AuthError, Codes | from synapse.api.errors import AuthError, Codes | ||||||
| from synapse.types import UserID | from synapse.types import UserID | ||||||
| from synapse.util.logcontext import preserve_context_over_fn | from synapse.util import logcontext | ||||||
| from synapse.util.metrics import Measure | from synapse.util.metrics import Measure | ||||||
| 
 | 
 | ||||||
| logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||||
| @ -209,8 +209,7 @@ class Auth(object): | |||||||
|                 default=[""] |                 default=[""] | ||||||
|             )[0] |             )[0] | ||||||
|             if user and access_token and ip_addr: |             if user and access_token and ip_addr: | ||||||
|                 preserve_context_over_fn( |                 logcontext.preserve_fn(self.store.insert_client_ip)( | ||||||
|                     self.store.insert_client_ip, |  | ||||||
|                     user=user, |                     user=user, | ||||||
|                     access_token=access_token, |                     access_token=access_token, | ||||||
|                     ip=ip_addr, |                     ip=ip_addr, | ||||||
|  | |||||||
| @ -1,5 +1,6 @@ | |||||||
| # -*- coding: utf-8 -*- | # -*- coding: utf-8 -*- | ||||||
| # Copyright 2014-2016 OpenMarket Ltd | # Copyright 2014-2016 OpenMarket Ltd | ||||||
|  | # Copyright 2017 Vector Creations Ltd | ||||||
| # | # | ||||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||||
| @ -44,6 +45,7 @@ class JoinRules(object): | |||||||
| class LoginType(object): | class LoginType(object): | ||||||
|     PASSWORD = u"m.login.password" |     PASSWORD = u"m.login.password" | ||||||
|     EMAIL_IDENTITY = u"m.login.email.identity" |     EMAIL_IDENTITY = u"m.login.email.identity" | ||||||
|  |     MSISDN = u"m.login.msisdn" | ||||||
|     RECAPTCHA = u"m.login.recaptcha" |     RECAPTCHA = u"m.login.recaptcha" | ||||||
|     DUMMY = u"m.login.dummy" |     DUMMY = u"m.login.dummy" | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -15,6 +15,7 @@ | |||||||
| 
 | 
 | ||||||
| """Contains exceptions and error codes.""" | """Contains exceptions and error codes.""" | ||||||
| 
 | 
 | ||||||
|  | import json | ||||||
| import logging | import logging | ||||||
| 
 | 
 | ||||||
| logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||||
| @ -50,27 +51,35 @@ class Codes(object): | |||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class CodeMessageException(RuntimeError): | class CodeMessageException(RuntimeError): | ||||||
|     """An exception with integer code and message string attributes.""" |     """An exception with integer code and message string attributes. | ||||||
| 
 | 
 | ||||||
|  |     Attributes: | ||||||
|  |         code (int): HTTP error code | ||||||
|  |         msg (str): string describing the error | ||||||
|  |     """ | ||||||
|     def __init__(self, code, msg): |     def __init__(self, code, msg): | ||||||
|         super(CodeMessageException, self).__init__("%d: %s" % (code, msg)) |         super(CodeMessageException, self).__init__("%d: %s" % (code, msg)) | ||||||
|         self.code = code |         self.code = code | ||||||
|         self.msg = msg |         self.msg = msg | ||||||
|         self.response_code_message = None |  | ||||||
| 
 | 
 | ||||||
|     def error_dict(self): |     def error_dict(self): | ||||||
|         return cs_error(self.msg) |         return cs_error(self.msg) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class SynapseError(CodeMessageException): | class SynapseError(CodeMessageException): | ||||||
|     """A base error which can be caught for all synapse events.""" |     """A base exception type for matrix errors which have an errcode and error | ||||||
|  |     message (as well as an HTTP status code). | ||||||
|  | 
 | ||||||
|  |     Attributes: | ||||||
|  |         errcode (str): Matrix error code e.g 'M_FORBIDDEN' | ||||||
|  |     """ | ||||||
|     def __init__(self, code, msg, errcode=Codes.UNKNOWN): |     def __init__(self, code, msg, errcode=Codes.UNKNOWN): | ||||||
|         """Constructs a synapse error. |         """Constructs a synapse error. | ||||||
| 
 | 
 | ||||||
|         Args: |         Args: | ||||||
|             code (int): The integer error code (an HTTP response code) |             code (int): The integer error code (an HTTP response code) | ||||||
|             msg (str): The human-readable error message. |             msg (str): The human-readable error message. | ||||||
|             err (str): The error code e.g 'M_FORBIDDEN' |             errcode (str): The matrix error code e.g 'M_FORBIDDEN' | ||||||
|         """ |         """ | ||||||
|         super(SynapseError, self).__init__(code, msg) |         super(SynapseError, self).__init__(code, msg) | ||||||
|         self.errcode = errcode |         self.errcode = errcode | ||||||
| @ -81,6 +90,39 @@ class SynapseError(CodeMessageException): | |||||||
|             self.errcode, |             self.errcode, | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|  |     @classmethod | ||||||
|  |     def from_http_response_exception(cls, err): | ||||||
|  |         """Make a SynapseError based on an HTTPResponseException | ||||||
|  | 
 | ||||||
|  |         This is useful when a proxied request has failed, and we need to | ||||||
|  |         decide how to map the failure onto a matrix error to send back to the | ||||||
|  |         client. | ||||||
|  | 
 | ||||||
|  |         An attempt is made to parse the body of the http response as a matrix | ||||||
|  |         error. If that succeeds, the errcode and error message from the body | ||||||
|  |         are used as the errcode and error message in the new synapse error. | ||||||
|  | 
 | ||||||
|  |         Otherwise, the errcode is set to M_UNKNOWN, and the error message is | ||||||
|  |         set to the reason code from the HTTP response. | ||||||
|  | 
 | ||||||
|  |         Args: | ||||||
|  |             err (HttpResponseException): | ||||||
|  | 
 | ||||||
|  |         Returns: | ||||||
|  |             SynapseError: | ||||||
|  |         """ | ||||||
|  |         # try to parse the body as json, to get better errcode/msg, but | ||||||
|  |         # default to M_UNKNOWN with the HTTP status as the error text | ||||||
|  |         try: | ||||||
|  |             j = json.loads(err.response) | ||||||
|  |         except ValueError: | ||||||
|  |             j = {} | ||||||
|  |         errcode = j.get('errcode', Codes.UNKNOWN) | ||||||
|  |         errmsg = j.get('error', err.msg) | ||||||
|  | 
 | ||||||
|  |         res = SynapseError(err.code, errmsg, errcode) | ||||||
|  |         return res | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
| class RegistrationError(SynapseError): | class RegistrationError(SynapseError): | ||||||
|     """An error raised when a registration event fails.""" |     """An error raised when a registration event fails.""" | ||||||
| @ -106,13 +148,11 @@ class UnrecognizedRequestError(SynapseError): | |||||||
| 
 | 
 | ||||||
| class NotFoundError(SynapseError): | class NotFoundError(SynapseError): | ||||||
|     """An error indicating we can't find the thing you asked for""" |     """An error indicating we can't find the thing you asked for""" | ||||||
|     def __init__(self, *args, **kwargs): |     def __init__(self, msg="Not found", errcode=Codes.NOT_FOUND): | ||||||
|         if "errcode" not in kwargs: |  | ||||||
|             kwargs["errcode"] = Codes.NOT_FOUND |  | ||||||
|         super(NotFoundError, self).__init__( |         super(NotFoundError, self).__init__( | ||||||
|             404, |             404, | ||||||
|             "Not found", |             msg, | ||||||
|             **kwargs |             errcode=errcode | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @ -173,7 +213,6 @@ class LimitExceededError(SynapseError): | |||||||
|                  errcode=Codes.LIMIT_EXCEEDED): |                  errcode=Codes.LIMIT_EXCEEDED): | ||||||
|         super(LimitExceededError, self).__init__(code, msg, errcode) |         super(LimitExceededError, self).__init__(code, msg, errcode) | ||||||
|         self.retry_after_ms = retry_after_ms |         self.retry_after_ms = retry_after_ms | ||||||
|         self.response_code_message = "Too Many Requests" |  | ||||||
| 
 | 
 | ||||||
|     def error_dict(self): |     def error_dict(self): | ||||||
|         return cs_error( |         return cs_error( | ||||||
| @ -243,6 +282,19 @@ class FederationError(RuntimeError): | |||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class HttpResponseException(CodeMessageException): | class HttpResponseException(CodeMessageException): | ||||||
|  |     """ | ||||||
|  |     Represents an HTTP-level failure of an outbound request | ||||||
|  | 
 | ||||||
|  |     Attributes: | ||||||
|  |         response (str): body of response | ||||||
|  |     """ | ||||||
|     def __init__(self, code, msg, response): |     def __init__(self, code, msg, response): | ||||||
|         self.response = response |         """ | ||||||
|  | 
 | ||||||
|  |         Args: | ||||||
|  |             code (int): HTTP status code | ||||||
|  |             msg (str): reason phrase from HTTP response status line | ||||||
|  |             response (str): body of response | ||||||
|  |         """ | ||||||
|         super(HttpResponseException, self).__init__(code, msg) |         super(HttpResponseException, self).__init__(code, msg) | ||||||
|  |         self.response = response | ||||||
|  | |||||||
| @ -13,11 +13,174 @@ | |||||||
| # See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||||
| # limitations under the License. | # limitations under the License. | ||||||
| from synapse.api.errors import SynapseError | from synapse.api.errors import SynapseError | ||||||
|  | from synapse.storage.presence import UserPresenceState | ||||||
| from synapse.types import UserID, RoomID | from synapse.types import UserID, RoomID | ||||||
| 
 |  | ||||||
| from twisted.internet import defer | from twisted.internet import defer | ||||||
| 
 | 
 | ||||||
| import ujson as json | import ujson as json | ||||||
|  | import jsonschema | ||||||
|  | from jsonschema import FormatChecker | ||||||
|  | 
 | ||||||
|  | FILTER_SCHEMA = { | ||||||
|  |     "additionalProperties": False, | ||||||
|  |     "type": "object", | ||||||
|  |     "properties": { | ||||||
|  |         "limit": { | ||||||
|  |             "type": "number" | ||||||
|  |         }, | ||||||
|  |         "senders": { | ||||||
|  |             "$ref": "#/definitions/user_id_array" | ||||||
|  |         }, | ||||||
|  |         "not_senders": { | ||||||
|  |             "$ref": "#/definitions/user_id_array" | ||||||
|  |         }, | ||||||
|  |         # TODO: We don't limit event type values but we probably should... | ||||||
|  |         # check types are valid event types | ||||||
|  |         "types": { | ||||||
|  |             "type": "array", | ||||||
|  |             "items": { | ||||||
|  |                 "type": "string" | ||||||
|  |             } | ||||||
|  |         }, | ||||||
|  |         "not_types": { | ||||||
|  |             "type": "array", | ||||||
|  |             "items": { | ||||||
|  |                 "type": "string" | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | ROOM_FILTER_SCHEMA = { | ||||||
|  |     "additionalProperties": False, | ||||||
|  |     "type": "object", | ||||||
|  |     "properties": { | ||||||
|  |         "not_rooms": { | ||||||
|  |             "$ref": "#/definitions/room_id_array" | ||||||
|  |         }, | ||||||
|  |         "rooms": { | ||||||
|  |             "$ref": "#/definitions/room_id_array" | ||||||
|  |         }, | ||||||
|  |         "ephemeral": { | ||||||
|  |             "$ref": "#/definitions/room_event_filter" | ||||||
|  |         }, | ||||||
|  |         "include_leave": { | ||||||
|  |             "type": "boolean" | ||||||
|  |         }, | ||||||
|  |         "state": { | ||||||
|  |             "$ref": "#/definitions/room_event_filter" | ||||||
|  |         }, | ||||||
|  |         "timeline": { | ||||||
|  |             "$ref": "#/definitions/room_event_filter" | ||||||
|  |         }, | ||||||
|  |         "account_data": { | ||||||
|  |             "$ref": "#/definitions/room_event_filter" | ||||||
|  |         }, | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | ROOM_EVENT_FILTER_SCHEMA = { | ||||||
|  |     "additionalProperties": False, | ||||||
|  |     "type": "object", | ||||||
|  |     "properties": { | ||||||
|  |         "limit": { | ||||||
|  |             "type": "number" | ||||||
|  |         }, | ||||||
|  |         "senders": { | ||||||
|  |             "$ref": "#/definitions/user_id_array" | ||||||
|  |         }, | ||||||
|  |         "not_senders": { | ||||||
|  |             "$ref": "#/definitions/user_id_array" | ||||||
|  |         }, | ||||||
|  |         "types": { | ||||||
|  |             "type": "array", | ||||||
|  |             "items": { | ||||||
|  |                 "type": "string" | ||||||
|  |             } | ||||||
|  |         }, | ||||||
|  |         "not_types": { | ||||||
|  |             "type": "array", | ||||||
|  |             "items": { | ||||||
|  |                 "type": "string" | ||||||
|  |             } | ||||||
|  |         }, | ||||||
|  |         "rooms": { | ||||||
|  |             "$ref": "#/definitions/room_id_array" | ||||||
|  |         }, | ||||||
|  |         "not_rooms": { | ||||||
|  |             "$ref": "#/definitions/room_id_array" | ||||||
|  |         }, | ||||||
|  |         "contains_url": { | ||||||
|  |             "type": "boolean" | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | USER_ID_ARRAY_SCHEMA = { | ||||||
|  |     "type": "array", | ||||||
|  |     "items": { | ||||||
|  |         "type": "string", | ||||||
|  |         "format": "matrix_user_id" | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | ROOM_ID_ARRAY_SCHEMA = { | ||||||
|  |     "type": "array", | ||||||
|  |     "items": { | ||||||
|  |         "type": "string", | ||||||
|  |         "format": "matrix_room_id" | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | USER_FILTER_SCHEMA = { | ||||||
|  |     "$schema": "http://json-schema.org/draft-04/schema#", | ||||||
|  |     "description": "schema for a Sync filter", | ||||||
|  |     "type": "object", | ||||||
|  |     "definitions": { | ||||||
|  |         "room_id_array": ROOM_ID_ARRAY_SCHEMA, | ||||||
|  |         "user_id_array": USER_ID_ARRAY_SCHEMA, | ||||||
|  |         "filter": FILTER_SCHEMA, | ||||||
|  |         "room_filter": ROOM_FILTER_SCHEMA, | ||||||
|  |         "room_event_filter": ROOM_EVENT_FILTER_SCHEMA | ||||||
|  |     }, | ||||||
|  |     "properties": { | ||||||
|  |         "presence": { | ||||||
|  |             "$ref": "#/definitions/filter" | ||||||
|  |         }, | ||||||
|  |         "account_data": { | ||||||
|  |             "$ref": "#/definitions/filter" | ||||||
|  |         }, | ||||||
|  |         "room": { | ||||||
|  |             "$ref": "#/definitions/room_filter" | ||||||
|  |         }, | ||||||
|  |         "event_format": { | ||||||
|  |             "type": "string", | ||||||
|  |             "enum": ["client", "federation"] | ||||||
|  |         }, | ||||||
|  |         "event_fields": { | ||||||
|  |             "type": "array", | ||||||
|  |             "items": { | ||||||
|  |                 "type": "string", | ||||||
|  |                 # Don't allow '\\' in event field filters. This makes matching | ||||||
|  |                 # events a lot easier as we can then use a negative lookbehind | ||||||
|  |                 # assertion to split '\.' If we allowed \\ then it would | ||||||
|  |                 # incorrectly split '\\.' See synapse.events.utils.serialize_event | ||||||
|  |                 "pattern": "^((?!\\\).)*$" | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |     }, | ||||||
|  |     "additionalProperties": False | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | @FormatChecker.cls_checks('matrix_room_id') | ||||||
|  | def matrix_room_id_validator(room_id_str): | ||||||
|  |     return RoomID.from_string(room_id_str) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | @FormatChecker.cls_checks('matrix_user_id') | ||||||
|  | def matrix_user_id_validator(user_id_str): | ||||||
|  |     return UserID.from_string(user_id_str) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class Filtering(object): | class Filtering(object): | ||||||
| @ -52,98 +215,11 @@ class Filtering(object): | |||||||
|         # NB: Filters are the complete json blobs. "Definitions" are an |         # NB: Filters are the complete json blobs. "Definitions" are an | ||||||
|         # individual top-level key e.g. public_user_data. Filters are made of |         # individual top-level key e.g. public_user_data. Filters are made of | ||||||
|         # many definitions. |         # many definitions. | ||||||
| 
 |         try: | ||||||
|         top_level_definitions = [ |             jsonschema.validate(user_filter_json, USER_FILTER_SCHEMA, | ||||||
|             "presence", "account_data" |                                 format_checker=FormatChecker()) | ||||||
|         ] |         except jsonschema.ValidationError as e: | ||||||
| 
 |             raise SynapseError(400, e.message) | ||||||
|         room_level_definitions = [ |  | ||||||
|             "state", "timeline", "ephemeral", "account_data" |  | ||||||
|         ] |  | ||||||
| 
 |  | ||||||
|         for key in top_level_definitions: |  | ||||||
|             if key in user_filter_json: |  | ||||||
|                 self._check_definition(user_filter_json[key]) |  | ||||||
| 
 |  | ||||||
|         if "room" in user_filter_json: |  | ||||||
|             self._check_definition_room_lists(user_filter_json["room"]) |  | ||||||
|             for key in room_level_definitions: |  | ||||||
|                 if key in user_filter_json["room"]: |  | ||||||
|                     self._check_definition(user_filter_json["room"][key]) |  | ||||||
| 
 |  | ||||||
|         if "event_fields" in user_filter_json: |  | ||||||
|             if type(user_filter_json["event_fields"]) != list: |  | ||||||
|                 raise SynapseError(400, "event_fields must be a list of strings") |  | ||||||
|             for field in user_filter_json["event_fields"]: |  | ||||||
|                 if not isinstance(field, basestring): |  | ||||||
|                     raise SynapseError(400, "Event field must be a string") |  | ||||||
|                 # Don't allow '\\' in event field filters. This makes matching |  | ||||||
|                 # events a lot easier as we can then use a negative lookbehind |  | ||||||
|                 # assertion to split '\.' If we allowed \\ then it would |  | ||||||
|                 # incorrectly split '\\.' See synapse.events.utils.serialize_event |  | ||||||
|                 if r'\\' in field: |  | ||||||
|                     raise SynapseError( |  | ||||||
|                         400, r'The escape character \ cannot itself be escaped' |  | ||||||
|                     ) |  | ||||||
| 
 |  | ||||||
|     def _check_definition_room_lists(self, definition): |  | ||||||
|         """Check that "rooms" and "not_rooms" are lists of room ids if they |  | ||||||
|         are present |  | ||||||
| 
 |  | ||||||
|         Args: |  | ||||||
|             definition(dict): The filter definition |  | ||||||
|         Raises: |  | ||||||
|             SynapseError: If there was a problem with this definition. |  | ||||||
|         """ |  | ||||||
|         # check rooms are valid room IDs |  | ||||||
|         room_id_keys = ["rooms", "not_rooms"] |  | ||||||
|         for key in room_id_keys: |  | ||||||
|             if key in definition: |  | ||||||
|                 if type(definition[key]) != list: |  | ||||||
|                     raise SynapseError(400, "Expected %s to be a list." % key) |  | ||||||
|                 for room_id in definition[key]: |  | ||||||
|                     RoomID.from_string(room_id) |  | ||||||
| 
 |  | ||||||
|     def _check_definition(self, definition): |  | ||||||
|         """Check if the provided definition is valid. |  | ||||||
| 
 |  | ||||||
|         This inspects not only the types but also the values to make sure they |  | ||||||
|         make sense. |  | ||||||
| 
 |  | ||||||
|         Args: |  | ||||||
|             definition(dict): The filter definition |  | ||||||
|         Raises: |  | ||||||
|             SynapseError: If there was a problem with this definition. |  | ||||||
|         """ |  | ||||||
|         # NB: Filters are the complete json blobs. "Definitions" are an |  | ||||||
|         # individual top-level key e.g. public_user_data. Filters are made of |  | ||||||
|         # many definitions. |  | ||||||
|         if type(definition) != dict: |  | ||||||
|             raise SynapseError( |  | ||||||
|                 400, "Expected JSON object, not %s" % (definition,) |  | ||||||
|             ) |  | ||||||
| 
 |  | ||||||
|         self._check_definition_room_lists(definition) |  | ||||||
| 
 |  | ||||||
|         # check senders are valid user IDs |  | ||||||
|         user_id_keys = ["senders", "not_senders"] |  | ||||||
|         for key in user_id_keys: |  | ||||||
|             if key in definition: |  | ||||||
|                 if type(definition[key]) != list: |  | ||||||
|                     raise SynapseError(400, "Expected %s to be a list." % key) |  | ||||||
|                 for user_id in definition[key]: |  | ||||||
|                     UserID.from_string(user_id) |  | ||||||
| 
 |  | ||||||
|         # TODO: We don't limit event type values but we probably should... |  | ||||||
|         # check types are valid event types |  | ||||||
|         event_keys = ["types", "not_types"] |  | ||||||
|         for key in event_keys: |  | ||||||
|             if key in definition: |  | ||||||
|                 if type(definition[key]) != list: |  | ||||||
|                     raise SynapseError(400, "Expected %s to be a list." % key) |  | ||||||
|                 for event_type in definition[key]: |  | ||||||
|                     if not isinstance(event_type, basestring): |  | ||||||
|                         raise SynapseError(400, "Event type should be a string") |  | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class FilterCollection(object): | class FilterCollection(object): | ||||||
| @ -253,19 +329,35 @@ class Filter(object): | |||||||
|         Returns: |         Returns: | ||||||
|             bool: True if the event matches |             bool: True if the event matches | ||||||
|         """ |         """ | ||||||
|  |         # We usually get the full "events" as dictionaries coming through, | ||||||
|  |         # except for presence which actually gets passed around as its own | ||||||
|  |         # namedtuple type. | ||||||
|  |         if isinstance(event, UserPresenceState): | ||||||
|  |             sender = event.user_id | ||||||
|  |             room_id = None | ||||||
|  |             ev_type = "m.presence" | ||||||
|  |             is_url = False | ||||||
|  |         else: | ||||||
|             sender = event.get("sender", None) |             sender = event.get("sender", None) | ||||||
|             if not sender: |             if not sender: | ||||||
|             # Presence events have their 'sender' in content.user_id |                 # Presence events had their 'sender' in content.user_id, but are | ||||||
|  |                 # now handled above. We don't know if anything else uses this | ||||||
|  |                 # form. TODO: Check this and probably remove it. | ||||||
|                 content = event.get("content") |                 content = event.get("content") | ||||||
|             # account_data has been allowed to have non-dict content, so check type first |                 # account_data has been allowed to have non-dict content, so | ||||||
|  |                 # check type first | ||||||
|                 if isinstance(content, dict): |                 if isinstance(content, dict): | ||||||
|                     sender = content.get("user_id") |                     sender = content.get("user_id") | ||||||
| 
 | 
 | ||||||
|  |             room_id = event.get("room_id", None) | ||||||
|  |             ev_type = event.get("type", None) | ||||||
|  |             is_url = "url" in event.get("content", {}) | ||||||
|  | 
 | ||||||
|         return self.check_fields( |         return self.check_fields( | ||||||
|             event.get("room_id", None), |             room_id, | ||||||
|             sender, |             sender, | ||||||
|             event.get("type", None), |             ev_type, | ||||||
|             "url" in event.get("content", {}) |             is_url, | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|     def check_fields(self, room_id, sender, event_type, contains_url): |     def check_fields(self, room_id, sender, event_type, contains_url): | ||||||
|  | |||||||
| @ -29,7 +29,7 @@ from synapse.replication.slave.storage.registration import SlavedRegistrationSto | |||||||
| from synapse.storage.engines import create_engine | from synapse.storage.engines import create_engine | ||||||
| from synapse.util.async import sleep | from synapse.util.async import sleep | ||||||
| from synapse.util.httpresourcetree import create_resource_tree | from synapse.util.httpresourcetree import create_resource_tree | ||||||
| from synapse.util.logcontext import LoggingContext | from synapse.util.logcontext import LoggingContext, PreserveLoggingContext | ||||||
| from synapse.util.manhole import manhole | from synapse.util.manhole import manhole | ||||||
| from synapse.util.rlimit import change_resource_limit | from synapse.util.rlimit import change_resource_limit | ||||||
| from synapse.util.versionstring import get_version_string | from synapse.util.versionstring import get_version_string | ||||||
| @ -157,7 +157,7 @@ def start(config_options): | |||||||
| 
 | 
 | ||||||
|     assert config.worker_app == "synapse.app.appservice" |     assert config.worker_app == "synapse.app.appservice" | ||||||
| 
 | 
 | ||||||
|     setup_logging(config.worker_log_config, config.worker_log_file) |     setup_logging(config, use_worker_options=True) | ||||||
| 
 | 
 | ||||||
|     events.USE_FROZEN_DICTS = config.use_frozen_dicts |     events.USE_FROZEN_DICTS = config.use_frozen_dicts | ||||||
| 
 | 
 | ||||||
| @ -187,7 +187,11 @@ def start(config_options): | |||||||
|     ps.start_listening(config.worker_listeners) |     ps.start_listening(config.worker_listeners) | ||||||
| 
 | 
 | ||||||
|     def run(): |     def run(): | ||||||
|         with LoggingContext("run"): |         # make sure that we run the reactor with the sentinel log context, | ||||||
|  |         # otherwise other PreserveLoggingContext instances will get confused | ||||||
|  |         # and complain when they see the logcontext arbitrarily swapping | ||||||
|  |         # between the sentinel and `run` logcontexts. | ||||||
|  |         with PreserveLoggingContext(): | ||||||
|             logger.info("Running") |             logger.info("Running") | ||||||
|             change_resource_limit(config.soft_file_limit) |             change_resource_limit(config.soft_file_limit) | ||||||
|             if config.gc_thresholds: |             if config.gc_thresholds: | ||||||
|  | |||||||
| @ -29,13 +29,14 @@ from synapse.replication.slave.storage.keys import SlavedKeyStore | |||||||
| from synapse.replication.slave.storage.room import RoomStore | from synapse.replication.slave.storage.room import RoomStore | ||||||
| from synapse.replication.slave.storage.directory import DirectoryStore | from synapse.replication.slave.storage.directory import DirectoryStore | ||||||
| from synapse.replication.slave.storage.registration import SlavedRegistrationStore | from synapse.replication.slave.storage.registration import SlavedRegistrationStore | ||||||
|  | from synapse.replication.slave.storage.transactions import TransactionStore | ||||||
| from synapse.rest.client.v1.room import PublicRoomListRestServlet | from synapse.rest.client.v1.room import PublicRoomListRestServlet | ||||||
| from synapse.server import HomeServer | from synapse.server import HomeServer | ||||||
| from synapse.storage.client_ips import ClientIpStore | from synapse.storage.client_ips import ClientIpStore | ||||||
| from synapse.storage.engines import create_engine | from synapse.storage.engines import create_engine | ||||||
| from synapse.util.async import sleep | from synapse.util.async import sleep | ||||||
| from synapse.util.httpresourcetree import create_resource_tree | from synapse.util.httpresourcetree import create_resource_tree | ||||||
| from synapse.util.logcontext import LoggingContext | from synapse.util.logcontext import LoggingContext, PreserveLoggingContext | ||||||
| from synapse.util.manhole import manhole | from synapse.util.manhole import manhole | ||||||
| from synapse.util.rlimit import change_resource_limit | from synapse.util.rlimit import change_resource_limit | ||||||
| from synapse.util.versionstring import get_version_string | from synapse.util.versionstring import get_version_string | ||||||
| @ -63,6 +64,7 @@ class ClientReaderSlavedStore( | |||||||
|     DirectoryStore, |     DirectoryStore, | ||||||
|     SlavedApplicationServiceStore, |     SlavedApplicationServiceStore, | ||||||
|     SlavedRegistrationStore, |     SlavedRegistrationStore, | ||||||
|  |     TransactionStore, | ||||||
|     BaseSlavedStore, |     BaseSlavedStore, | ||||||
|     ClientIpStore,  # After BaseSlavedStore because the constructor is different |     ClientIpStore,  # After BaseSlavedStore because the constructor is different | ||||||
| ): | ): | ||||||
| @ -171,7 +173,7 @@ def start(config_options): | |||||||
| 
 | 
 | ||||||
|     assert config.worker_app == "synapse.app.client_reader" |     assert config.worker_app == "synapse.app.client_reader" | ||||||
| 
 | 
 | ||||||
|     setup_logging(config.worker_log_config, config.worker_log_file) |     setup_logging(config, use_worker_options=True) | ||||||
| 
 | 
 | ||||||
|     events.USE_FROZEN_DICTS = config.use_frozen_dicts |     events.USE_FROZEN_DICTS = config.use_frozen_dicts | ||||||
| 
 | 
 | ||||||
| @ -193,7 +195,11 @@ def start(config_options): | |||||||
|     ss.start_listening(config.worker_listeners) |     ss.start_listening(config.worker_listeners) | ||||||
| 
 | 
 | ||||||
|     def run(): |     def run(): | ||||||
|         with LoggingContext("run"): |         # make sure that we run the reactor with the sentinel log context, | ||||||
|  |         # otherwise other PreserveLoggingContext instances will get confused | ||||||
|  |         # and complain when they see the logcontext arbitrarily swapping | ||||||
|  |         # between the sentinel and `run` logcontexts. | ||||||
|  |         with PreserveLoggingContext(): | ||||||
|             logger.info("Running") |             logger.info("Running") | ||||||
|             change_resource_limit(config.soft_file_limit) |             change_resource_limit(config.soft_file_limit) | ||||||
|             if config.gc_thresholds: |             if config.gc_thresholds: | ||||||
|  | |||||||
| @ -31,7 +31,7 @@ from synapse.server import HomeServer | |||||||
| from synapse.storage.engines import create_engine | from synapse.storage.engines import create_engine | ||||||
| from synapse.util.async import sleep | from synapse.util.async import sleep | ||||||
| from synapse.util.httpresourcetree import create_resource_tree | from synapse.util.httpresourcetree import create_resource_tree | ||||||
| from synapse.util.logcontext import LoggingContext | from synapse.util.logcontext import LoggingContext, PreserveLoggingContext | ||||||
| from synapse.util.manhole import manhole | from synapse.util.manhole import manhole | ||||||
| from synapse.util.rlimit import change_resource_limit | from synapse.util.rlimit import change_resource_limit | ||||||
| from synapse.util.versionstring import get_version_string | from synapse.util.versionstring import get_version_string | ||||||
| @ -162,7 +162,7 @@ def start(config_options): | |||||||
| 
 | 
 | ||||||
|     assert config.worker_app == "synapse.app.federation_reader" |     assert config.worker_app == "synapse.app.federation_reader" | ||||||
| 
 | 
 | ||||||
|     setup_logging(config.worker_log_config, config.worker_log_file) |     setup_logging(config, use_worker_options=True) | ||||||
| 
 | 
 | ||||||
|     events.USE_FROZEN_DICTS = config.use_frozen_dicts |     events.USE_FROZEN_DICTS = config.use_frozen_dicts | ||||||
| 
 | 
 | ||||||
| @ -184,7 +184,11 @@ def start(config_options): | |||||||
|     ss.start_listening(config.worker_listeners) |     ss.start_listening(config.worker_listeners) | ||||||
| 
 | 
 | ||||||
|     def run(): |     def run(): | ||||||
|         with LoggingContext("run"): |         # make sure that we run the reactor with the sentinel log context, | ||||||
|  |         # otherwise other PreserveLoggingContext instances will get confused | ||||||
|  |         # and complain when they see the logcontext arbitrarily swapping | ||||||
|  |         # between the sentinel and `run` logcontexts. | ||||||
|  |         with PreserveLoggingContext(): | ||||||
|             logger.info("Running") |             logger.info("Running") | ||||||
|             change_resource_limit(config.soft_file_limit) |             change_resource_limit(config.soft_file_limit) | ||||||
|             if config.gc_thresholds: |             if config.gc_thresholds: | ||||||
|  | |||||||
| @ -35,7 +35,7 @@ from synapse.storage.engines import create_engine | |||||||
| from synapse.storage.presence import UserPresenceState | from synapse.storage.presence import UserPresenceState | ||||||
| from synapse.util.async import sleep | from synapse.util.async import sleep | ||||||
| from synapse.util.httpresourcetree import create_resource_tree | from synapse.util.httpresourcetree import create_resource_tree | ||||||
| from synapse.util.logcontext import LoggingContext | from synapse.util.logcontext import LoggingContext, PreserveLoggingContext | ||||||
| from synapse.util.manhole import manhole | from synapse.util.manhole import manhole | ||||||
| from synapse.util.rlimit import change_resource_limit | from synapse.util.rlimit import change_resource_limit | ||||||
| from synapse.util.versionstring import get_version_string | from synapse.util.versionstring import get_version_string | ||||||
| @ -160,7 +160,7 @@ def start(config_options): | |||||||
| 
 | 
 | ||||||
|     assert config.worker_app == "synapse.app.federation_sender" |     assert config.worker_app == "synapse.app.federation_sender" | ||||||
| 
 | 
 | ||||||
|     setup_logging(config.worker_log_config, config.worker_log_file) |     setup_logging(config, use_worker_options=True) | ||||||
| 
 | 
 | ||||||
|     events.USE_FROZEN_DICTS = config.use_frozen_dicts |     events.USE_FROZEN_DICTS = config.use_frozen_dicts | ||||||
| 
 | 
 | ||||||
| @ -193,7 +193,11 @@ def start(config_options): | |||||||
|     ps.start_listening(config.worker_listeners) |     ps.start_listening(config.worker_listeners) | ||||||
| 
 | 
 | ||||||
|     def run(): |     def run(): | ||||||
|         with LoggingContext("run"): |         # make sure that we run the reactor with the sentinel log context, | ||||||
|  |         # otherwise other PreserveLoggingContext instances will get confused | ||||||
|  |         # and complain when they see the logcontext arbitrarily swapping | ||||||
|  |         # between the sentinel and `run` logcontexts. | ||||||
|  |         with PreserveLoggingContext(): | ||||||
|             logger.info("Running") |             logger.info("Running") | ||||||
|             change_resource_limit(config.soft_file_limit) |             change_resource_limit(config.soft_file_limit) | ||||||
|             if config.gc_thresholds: |             if config.gc_thresholds: | ||||||
|  | |||||||
| @ -20,6 +20,8 @@ import gc | |||||||
| import logging | import logging | ||||||
| import os | import os | ||||||
| import sys | import sys | ||||||
|  | 
 | ||||||
|  | import synapse.config.logger | ||||||
| from synapse.config._base import ConfigError | from synapse.config._base import ConfigError | ||||||
| 
 | 
 | ||||||
| from synapse.python_dependencies import ( | from synapse.python_dependencies import ( | ||||||
| @ -50,7 +52,7 @@ from synapse.api.urls import ( | |||||||
| ) | ) | ||||||
| from synapse.config.homeserver import HomeServerConfig | from synapse.config.homeserver import HomeServerConfig | ||||||
| from synapse.crypto import context_factory | from synapse.crypto import context_factory | ||||||
| from synapse.util.logcontext import LoggingContext | from synapse.util.logcontext import LoggingContext, PreserveLoggingContext | ||||||
| from synapse.metrics import register_memory_metrics, get_metrics_for | from synapse.metrics import register_memory_metrics, get_metrics_for | ||||||
| from synapse.metrics.resource import MetricsResource, METRICS_PREFIX | from synapse.metrics.resource import MetricsResource, METRICS_PREFIX | ||||||
| from synapse.replication.resource import ReplicationResource, REPLICATION_PREFIX | from synapse.replication.resource import ReplicationResource, REPLICATION_PREFIX | ||||||
| @ -286,7 +288,7 @@ def setup(config_options): | |||||||
|         # generating config files and shouldn't try to continue. |         # generating config files and shouldn't try to continue. | ||||||
|         sys.exit(0) |         sys.exit(0) | ||||||
| 
 | 
 | ||||||
|     config.setup_logging() |     synapse.config.logger.setup_logging(config, use_worker_options=False) | ||||||
| 
 | 
 | ||||||
|     # check any extra requirements we have now we have a config |     # check any extra requirements we have now we have a config | ||||||
|     check_requirements(config) |     check_requirements(config) | ||||||
| @ -454,7 +456,12 @@ def run(hs): | |||||||
|     def in_thread(): |     def in_thread(): | ||||||
|         # Uncomment to enable tracing of log context changes. |         # Uncomment to enable tracing of log context changes. | ||||||
|         # sys.settrace(logcontext_tracer) |         # sys.settrace(logcontext_tracer) | ||||||
|         with LoggingContext("run"): | 
 | ||||||
|  |         # make sure that we run the reactor with the sentinel log context, | ||||||
|  |         # otherwise other PreserveLoggingContext instances will get confused | ||||||
|  |         # and complain when they see the logcontext arbitrarily swapping | ||||||
|  |         # between the sentinel and `run` logcontexts. | ||||||
|  |         with PreserveLoggingContext(): | ||||||
|             change_resource_limit(hs.config.soft_file_limit) |             change_resource_limit(hs.config.soft_file_limit) | ||||||
|             if hs.config.gc_thresholds: |             if hs.config.gc_thresholds: | ||||||
|                 gc.set_threshold(*hs.config.gc_thresholds) |                 gc.set_threshold(*hs.config.gc_thresholds) | ||||||
|  | |||||||
| @ -24,6 +24,7 @@ from synapse.metrics.resource import MetricsResource, METRICS_PREFIX | |||||||
| from synapse.replication.slave.storage._base import BaseSlavedStore | from synapse.replication.slave.storage._base import BaseSlavedStore | ||||||
| from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore | from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore | ||||||
| from synapse.replication.slave.storage.registration import SlavedRegistrationStore | from synapse.replication.slave.storage.registration import SlavedRegistrationStore | ||||||
|  | from synapse.replication.slave.storage.transactions import TransactionStore | ||||||
| from synapse.rest.media.v0.content_repository import ContentRepoResource | from synapse.rest.media.v0.content_repository import ContentRepoResource | ||||||
| from synapse.rest.media.v1.media_repository import MediaRepositoryResource | from synapse.rest.media.v1.media_repository import MediaRepositoryResource | ||||||
| from synapse.server import HomeServer | from synapse.server import HomeServer | ||||||
| @ -32,7 +33,7 @@ from synapse.storage.engines import create_engine | |||||||
| from synapse.storage.media_repository import MediaRepositoryStore | from synapse.storage.media_repository import MediaRepositoryStore | ||||||
| from synapse.util.async import sleep | from synapse.util.async import sleep | ||||||
| from synapse.util.httpresourcetree import create_resource_tree | from synapse.util.httpresourcetree import create_resource_tree | ||||||
| from synapse.util.logcontext import LoggingContext | from synapse.util.logcontext import LoggingContext, PreserveLoggingContext | ||||||
| from synapse.util.manhole import manhole | from synapse.util.manhole import manhole | ||||||
| from synapse.util.rlimit import change_resource_limit | from synapse.util.rlimit import change_resource_limit | ||||||
| from synapse.util.versionstring import get_version_string | from synapse.util.versionstring import get_version_string | ||||||
| @ -59,6 +60,7 @@ logger = logging.getLogger("synapse.app.media_repository") | |||||||
| class MediaRepositorySlavedStore( | class MediaRepositorySlavedStore( | ||||||
|     SlavedApplicationServiceStore, |     SlavedApplicationServiceStore, | ||||||
|     SlavedRegistrationStore, |     SlavedRegistrationStore, | ||||||
|  |     TransactionStore, | ||||||
|     BaseSlavedStore, |     BaseSlavedStore, | ||||||
|     MediaRepositoryStore, |     MediaRepositoryStore, | ||||||
|     ClientIpStore, |     ClientIpStore, | ||||||
| @ -168,7 +170,7 @@ def start(config_options): | |||||||
| 
 | 
 | ||||||
|     assert config.worker_app == "synapse.app.media_repository" |     assert config.worker_app == "synapse.app.media_repository" | ||||||
| 
 | 
 | ||||||
|     setup_logging(config.worker_log_config, config.worker_log_file) |     setup_logging(config, use_worker_options=True) | ||||||
| 
 | 
 | ||||||
|     events.USE_FROZEN_DICTS = config.use_frozen_dicts |     events.USE_FROZEN_DICTS = config.use_frozen_dicts | ||||||
| 
 | 
 | ||||||
| @ -190,7 +192,11 @@ def start(config_options): | |||||||
|     ss.start_listening(config.worker_listeners) |     ss.start_listening(config.worker_listeners) | ||||||
| 
 | 
 | ||||||
|     def run(): |     def run(): | ||||||
|         with LoggingContext("run"): |         # make sure that we run the reactor with the sentinel log context, | ||||||
|  |         # otherwise other PreserveLoggingContext instances will get confused | ||||||
|  |         # and complain when they see the logcontext arbitrarily swapping | ||||||
|  |         # between the sentinel and `run` logcontexts. | ||||||
|  |         with PreserveLoggingContext(): | ||||||
|             logger.info("Running") |             logger.info("Running") | ||||||
|             change_resource_limit(config.soft_file_limit) |             change_resource_limit(config.soft_file_limit) | ||||||
|             if config.gc_thresholds: |             if config.gc_thresholds: | ||||||
|  | |||||||
| @ -31,7 +31,8 @@ from synapse.storage.engines import create_engine | |||||||
| from synapse.storage import DataStore | from synapse.storage import DataStore | ||||||
| from synapse.util.async import sleep | from synapse.util.async import sleep | ||||||
| from synapse.util.httpresourcetree import create_resource_tree | from synapse.util.httpresourcetree import create_resource_tree | ||||||
| from synapse.util.logcontext import LoggingContext, preserve_fn | from synapse.util.logcontext import LoggingContext, preserve_fn, \ | ||||||
|  |     PreserveLoggingContext | ||||||
| from synapse.util.manhole import manhole | from synapse.util.manhole import manhole | ||||||
| from synapse.util.rlimit import change_resource_limit | from synapse.util.rlimit import change_resource_limit | ||||||
| from synapse.util.versionstring import get_version_string | from synapse.util.versionstring import get_version_string | ||||||
| @ -245,7 +246,7 @@ def start(config_options): | |||||||
| 
 | 
 | ||||||
|     assert config.worker_app == "synapse.app.pusher" |     assert config.worker_app == "synapse.app.pusher" | ||||||
| 
 | 
 | ||||||
|     setup_logging(config.worker_log_config, config.worker_log_file) |     setup_logging(config, use_worker_options=True) | ||||||
| 
 | 
 | ||||||
|     events.USE_FROZEN_DICTS = config.use_frozen_dicts |     events.USE_FROZEN_DICTS = config.use_frozen_dicts | ||||||
| 
 | 
 | ||||||
| @ -275,7 +276,11 @@ def start(config_options): | |||||||
|     ps.start_listening(config.worker_listeners) |     ps.start_listening(config.worker_listeners) | ||||||
| 
 | 
 | ||||||
|     def run(): |     def run(): | ||||||
|         with LoggingContext("run"): |         # make sure that we run the reactor with the sentinel log context, | ||||||
|  |         # otherwise other PreserveLoggingContext instances will get confused | ||||||
|  |         # and complain when they see the logcontext arbitrarily swapping | ||||||
|  |         # between the sentinel and `run` logcontexts. | ||||||
|  |         with PreserveLoggingContext(): | ||||||
|             logger.info("Running") |             logger.info("Running") | ||||||
|             change_resource_limit(config.soft_file_limit) |             change_resource_limit(config.soft_file_limit) | ||||||
|             if config.gc_thresholds: |             if config.gc_thresholds: | ||||||
|  | |||||||
| @ -20,7 +20,6 @@ from synapse.api.constants import EventTypes, PresenceState | |||||||
| from synapse.config._base import ConfigError | from synapse.config._base import ConfigError | ||||||
| from synapse.config.homeserver import HomeServerConfig | from synapse.config.homeserver import HomeServerConfig | ||||||
| from synapse.config.logger import setup_logging | from synapse.config.logger import setup_logging | ||||||
| from synapse.events import FrozenEvent |  | ||||||
| from synapse.handlers.presence import PresenceHandler | from synapse.handlers.presence import PresenceHandler | ||||||
| from synapse.http.site import SynapseSite | from synapse.http.site import SynapseSite | ||||||
| from synapse.http.server import JsonResource | from synapse.http.server import JsonResource | ||||||
| @ -48,7 +47,8 @@ from synapse.storage.presence import PresenceStore, UserPresenceState | |||||||
| from synapse.storage.roommember import RoomMemberStore | from synapse.storage.roommember import RoomMemberStore | ||||||
| from synapse.util.async import sleep | from synapse.util.async import sleep | ||||||
| from synapse.util.httpresourcetree import create_resource_tree | from synapse.util.httpresourcetree import create_resource_tree | ||||||
| from synapse.util.logcontext import LoggingContext, preserve_fn | from synapse.util.logcontext import LoggingContext, preserve_fn, \ | ||||||
|  |     PreserveLoggingContext | ||||||
| from synapse.util.manhole import manhole | from synapse.util.manhole import manhole | ||||||
| from synapse.util.rlimit import change_resource_limit | from synapse.util.rlimit import change_resource_limit | ||||||
| from synapse.util.stringutils import random_string | from synapse.util.stringutils import random_string | ||||||
| @ -399,8 +399,7 @@ class SynchrotronServer(HomeServer): | |||||||
|                 position = row[position_index] |                 position = row[position_index] | ||||||
|                 user_id = row[user_index] |                 user_id = row[user_index] | ||||||
| 
 | 
 | ||||||
|                 rooms = yield store.get_rooms_for_user(user_id) |                 room_ids = yield store.get_rooms_for_user(user_id) | ||||||
|                 room_ids = [r.room_id for r in rooms] |  | ||||||
| 
 | 
 | ||||||
|                 notifier.on_new_event( |                 notifier.on_new_event( | ||||||
|                     "device_list_key", position, rooms=room_ids, |                     "device_list_key", position, rooms=room_ids, | ||||||
| @ -411,11 +410,16 @@ class SynchrotronServer(HomeServer): | |||||||
|             stream = result.get("events") |             stream = result.get("events") | ||||||
|             if stream: |             if stream: | ||||||
|                 max_position = stream["position"] |                 max_position = stream["position"] | ||||||
|  | 
 | ||||||
|  |                 event_map = yield store.get_events([row[1] for row in stream["rows"]]) | ||||||
|  | 
 | ||||||
|                 for row in stream["rows"]: |                 for row in stream["rows"]: | ||||||
|                     position = row[0] |                     position = row[0] | ||||||
|                     internal = json.loads(row[1]) |                     event_id = row[1] | ||||||
|                     event_json = json.loads(row[2]) |                     event = event_map.get(event_id, None) | ||||||
|                     event = FrozenEvent(event_json, internal_metadata_dict=internal) |                     if not event: | ||||||
|  |                         continue | ||||||
|  | 
 | ||||||
|                     extra_users = () |                     extra_users = () | ||||||
|                     if event.type == EventTypes.Member: |                     if event.type == EventTypes.Member: | ||||||
|                         extra_users = (event.state_key,) |                         extra_users = (event.state_key,) | ||||||
| @ -478,7 +482,7 @@ def start(config_options): | |||||||
| 
 | 
 | ||||||
|     assert config.worker_app == "synapse.app.synchrotron" |     assert config.worker_app == "synapse.app.synchrotron" | ||||||
| 
 | 
 | ||||||
|     setup_logging(config.worker_log_config, config.worker_log_file) |     setup_logging(config, use_worker_options=True) | ||||||
| 
 | 
 | ||||||
|     synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts |     synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts | ||||||
| 
 | 
 | ||||||
| @ -497,7 +501,11 @@ def start(config_options): | |||||||
|     ss.start_listening(config.worker_listeners) |     ss.start_listening(config.worker_listeners) | ||||||
| 
 | 
 | ||||||
|     def run(): |     def run(): | ||||||
|         with LoggingContext("run"): |         # make sure that we run the reactor with the sentinel log context, | ||||||
|  |         # otherwise other PreserveLoggingContext instances will get confused | ||||||
|  |         # and complain when they see the logcontext arbitrarily swapping | ||||||
|  |         # between the sentinel and `run` logcontexts. | ||||||
|  |         with PreserveLoggingContext(): | ||||||
|             logger.info("Running") |             logger.info("Running") | ||||||
|             change_resource_limit(config.soft_file_limit) |             change_resource_limit(config.soft_file_limit) | ||||||
|             if config.gc_thresholds: |             if config.gc_thresholds: | ||||||
|  | |||||||
| @ -23,14 +23,27 @@ import signal | |||||||
| import subprocess | import subprocess | ||||||
| import sys | import sys | ||||||
| import yaml | import yaml | ||||||
|  | import errno | ||||||
|  | import time | ||||||
| 
 | 
 | ||||||
| SYNAPSE = [sys.executable, "-B", "-m", "synapse.app.homeserver"] | SYNAPSE = [sys.executable, "-B", "-m", "synapse.app.homeserver"] | ||||||
| 
 | 
 | ||||||
| GREEN = "\x1b[1;32m" | GREEN = "\x1b[1;32m" | ||||||
|  | YELLOW = "\x1b[1;33m" | ||||||
| RED = "\x1b[1;31m" | RED = "\x1b[1;31m" | ||||||
| NORMAL = "\x1b[m" | NORMAL = "\x1b[m" | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | def pid_running(pid): | ||||||
|  |     try: | ||||||
|  |         os.kill(pid, 0) | ||||||
|  |         return True | ||||||
|  |     except OSError, err: | ||||||
|  |         if err.errno == errno.EPERM: | ||||||
|  |             return True | ||||||
|  |         return False | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| def write(message, colour=NORMAL, stream=sys.stdout): | def write(message, colour=NORMAL, stream=sys.stdout): | ||||||
|     if colour == NORMAL: |     if colour == NORMAL: | ||||||
|         stream.write(message + "\n") |         stream.write(message + "\n") | ||||||
| @ -38,6 +51,11 @@ def write(message, colour=NORMAL, stream=sys.stdout): | |||||||
|         stream.write(colour + message + NORMAL + "\n") |         stream.write(colour + message + NORMAL + "\n") | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | def abort(message, colour=RED, stream=sys.stderr): | ||||||
|  |     write(message, colour, stream) | ||||||
|  |     sys.exit(1) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| def start(configfile): | def start(configfile): | ||||||
|     write("Starting ...") |     write("Starting ...") | ||||||
|     args = SYNAPSE |     args = SYNAPSE | ||||||
| @ -45,7 +63,8 @@ def start(configfile): | |||||||
| 
 | 
 | ||||||
|     try: |     try: | ||||||
|         subprocess.check_call(args) |         subprocess.check_call(args) | ||||||
|         write("started synapse.app.homeserver(%r)" % (configfile,), colour=GREEN) |         write("started synapse.app.homeserver(%r)" % | ||||||
|  |               (configfile,), colour=GREEN) | ||||||
|     except subprocess.CalledProcessError as e: |     except subprocess.CalledProcessError as e: | ||||||
|         write( |         write( | ||||||
|             "error starting (exit code: %d); see above for logs" % e.returncode, |             "error starting (exit code: %d); see above for logs" % e.returncode, | ||||||
| @ -76,8 +95,16 @@ def start_worker(app, configfile, worker_configfile): | |||||||
| def stop(pidfile, app): | def stop(pidfile, app): | ||||||
|     if os.path.exists(pidfile): |     if os.path.exists(pidfile): | ||||||
|         pid = int(open(pidfile).read()) |         pid = int(open(pidfile).read()) | ||||||
|  |         try: | ||||||
|             os.kill(pid, signal.SIGTERM) |             os.kill(pid, signal.SIGTERM) | ||||||
|             write("stopped %s" % (app,), colour=GREEN) |             write("stopped %s" % (app,), colour=GREEN) | ||||||
|  |         except OSError, err: | ||||||
|  |             if err.errno == errno.ESRCH: | ||||||
|  |                 write("%s not running" % (app,), colour=YELLOW) | ||||||
|  |             elif err.errno == errno.EPERM: | ||||||
|  |                 abort("Cannot stop %s: Operation not permitted" % (app,)) | ||||||
|  |             else: | ||||||
|  |                 abort("Cannot stop %s: Unknown error" % (app,)) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| Worker = collections.namedtuple("Worker", [ | Worker = collections.namedtuple("Worker", [ | ||||||
| @ -190,7 +217,19 @@ def main(): | |||||||
|         if start_stop_synapse: |         if start_stop_synapse: | ||||||
|             stop(pidfile, "synapse.app.homeserver") |             stop(pidfile, "synapse.app.homeserver") | ||||||
| 
 | 
 | ||||||
|         # TODO: Wait for synapse to actually shutdown before starting it again |     # Wait for synapse to actually shutdown before starting it again | ||||||
|  |     if action == "restart": | ||||||
|  |         running_pids = [] | ||||||
|  |         if start_stop_synapse and os.path.exists(pidfile): | ||||||
|  |             running_pids.append(int(open(pidfile).read())) | ||||||
|  |         for worker in workers: | ||||||
|  |             if os.path.exists(worker.pidfile): | ||||||
|  |                 running_pids.append(int(open(worker.pidfile).read())) | ||||||
|  |         if len(running_pids) > 0: | ||||||
|  |             write("Waiting for process to exit before restarting...") | ||||||
|  |             for running_pid in running_pids: | ||||||
|  |                 while pid_running(running_pid): | ||||||
|  |                     time.sleep(0.2) | ||||||
| 
 | 
 | ||||||
|     if action == "start" or action == "restart": |     if action == "start" or action == "restart": | ||||||
|         if start_stop_synapse: |         if start_stop_synapse: | ||||||
|  | |||||||
| @ -45,7 +45,6 @@ handlers: | |||||||
|     maxBytes: 104857600 |     maxBytes: 104857600 | ||||||
|     backupCount: 10 |     backupCount: 10 | ||||||
|     filters: [context] |     filters: [context] | ||||||
|     level: INFO |  | ||||||
|   console: |   console: | ||||||
|     class: logging.StreamHandler |     class: logging.StreamHandler | ||||||
|     formatter: precise |     formatter: precise | ||||||
| @ -56,6 +55,8 @@ loggers: | |||||||
|         level: INFO |         level: INFO | ||||||
| 
 | 
 | ||||||
|     synapse.storage.SQL: |     synapse.storage.SQL: | ||||||
|  |         # beware: increasing this to DEBUG will make synapse log sensitive | ||||||
|  |         # information such as access tokens. | ||||||
|         level: INFO |         level: INFO | ||||||
| 
 | 
 | ||||||
| root: | root: | ||||||
| @ -68,6 +69,7 @@ class LoggingConfig(Config): | |||||||
| 
 | 
 | ||||||
|     def read_config(self, config): |     def read_config(self, config): | ||||||
|         self.verbosity = config.get("verbose", 0) |         self.verbosity = config.get("verbose", 0) | ||||||
|  |         self.no_redirect_stdio = config.get("no_redirect_stdio", False) | ||||||
|         self.log_config = self.abspath(config.get("log_config")) |         self.log_config = self.abspath(config.get("log_config")) | ||||||
|         self.log_file = self.abspath(config.get("log_file")) |         self.log_file = self.abspath(config.get("log_file")) | ||||||
| 
 | 
 | ||||||
| @ -77,10 +79,10 @@ class LoggingConfig(Config): | |||||||
|             os.path.join(config_dir_path, server_name + ".log.config") |             os.path.join(config_dir_path, server_name + ".log.config") | ||||||
|         ) |         ) | ||||||
|         return """ |         return """ | ||||||
|         # Logging verbosity level. |         # Logging verbosity level. Ignored if log_config is specified. | ||||||
|         verbose: 0 |         verbose: 0 | ||||||
| 
 | 
 | ||||||
|         # File to write logging to |         # File to write logging to. Ignored if log_config is specified. | ||||||
|         log_file: "%(log_file)s" |         log_file: "%(log_file)s" | ||||||
| 
 | 
 | ||||||
|         # A yaml python logging config file |         # A yaml python logging config file | ||||||
| @ -90,6 +92,8 @@ class LoggingConfig(Config): | |||||||
|     def read_arguments(self, args): |     def read_arguments(self, args): | ||||||
|         if args.verbose is not None: |         if args.verbose is not None: | ||||||
|             self.verbosity = args.verbose |             self.verbosity = args.verbose | ||||||
|  |         if args.no_redirect_stdio is not None: | ||||||
|  |             self.no_redirect_stdio = args.no_redirect_stdio | ||||||
|         if args.log_config is not None: |         if args.log_config is not None: | ||||||
|             self.log_config = args.log_config |             self.log_config = args.log_config | ||||||
|         if args.log_file is not None: |         if args.log_file is not None: | ||||||
| @ -99,16 +103,22 @@ class LoggingConfig(Config): | |||||||
|         logging_group = parser.add_argument_group("logging") |         logging_group = parser.add_argument_group("logging") | ||||||
|         logging_group.add_argument( |         logging_group.add_argument( | ||||||
|             '-v', '--verbose', dest="verbose", action='count', |             '-v', '--verbose', dest="verbose", action='count', | ||||||
|             help="The verbosity level." |             help="The verbosity level. Specify multiple times to increase " | ||||||
|  |             "verbosity. (Ignored if --log-config is specified.)" | ||||||
|         ) |         ) | ||||||
|         logging_group.add_argument( |         logging_group.add_argument( | ||||||
|             '-f', '--log-file', dest="log_file", |             '-f', '--log-file', dest="log_file", | ||||||
|             help="File to log to." |             help="File to log to. (Ignored if --log-config is specified.)" | ||||||
|         ) |         ) | ||||||
|         logging_group.add_argument( |         logging_group.add_argument( | ||||||
|             '--log-config', dest="log_config", default=None, |             '--log-config', dest="log_config", default=None, | ||||||
|             help="Python logging config file" |             help="Python logging config file" | ||||||
|         ) |         ) | ||||||
|  |         logging_group.add_argument( | ||||||
|  |             '-n', '--no-redirect-stdio', | ||||||
|  |             action='store_true', default=None, | ||||||
|  |             help="Do not redirect stdout/stderr to the log" | ||||||
|  |         ) | ||||||
| 
 | 
 | ||||||
|     def generate_files(self, config): |     def generate_files(self, config): | ||||||
|         log_config = config.get("log_config") |         log_config = config.get("log_config") | ||||||
| @ -118,11 +128,22 @@ class LoggingConfig(Config): | |||||||
|                     DEFAULT_LOG_CONFIG.substitute(log_file=config["log_file"]) |                     DEFAULT_LOG_CONFIG.substitute(log_file=config["log_file"]) | ||||||
|                 ) |                 ) | ||||||
| 
 | 
 | ||||||
|     def setup_logging(self): |  | ||||||
|         setup_logging(self.log_config, self.log_file, self.verbosity) |  | ||||||
| 
 | 
 | ||||||
|  | def setup_logging(config, use_worker_options=False): | ||||||
|  |     """ Set up python logging | ||||||
|  | 
 | ||||||
|  |     Args: | ||||||
|  |         config (LoggingConfig | synapse.config.workers.WorkerConfig): | ||||||
|  |             configuration data | ||||||
|  | 
 | ||||||
|  |         use_worker_options (bool): True to use 'worker_log_config' and | ||||||
|  |             'worker_log_file' options instead of 'log_config' and 'log_file'. | ||||||
|  |     """ | ||||||
|  |     log_config = (config.worker_log_config if use_worker_options | ||||||
|  |                   else config.log_config) | ||||||
|  |     log_file = (config.worker_log_file if use_worker_options | ||||||
|  |                 else config.log_file) | ||||||
| 
 | 
 | ||||||
| def setup_logging(log_config=None, log_file=None, verbosity=None): |  | ||||||
|     log_format = ( |     log_format = ( | ||||||
|         "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s" |         "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s" | ||||||
|         " - %(message)s" |         " - %(message)s" | ||||||
| @ -131,9 +152,9 @@ def setup_logging(log_config=None, log_file=None, verbosity=None): | |||||||
| 
 | 
 | ||||||
|         level = logging.INFO |         level = logging.INFO | ||||||
|         level_for_storage = logging.INFO |         level_for_storage = logging.INFO | ||||||
|         if verbosity: |         if config.verbosity: | ||||||
|             level = logging.DEBUG |             level = logging.DEBUG | ||||||
|             if verbosity > 1: |             if config.verbosity > 1: | ||||||
|                 level_for_storage = logging.DEBUG |                 level_for_storage = logging.DEBUG | ||||||
| 
 | 
 | ||||||
|         # FIXME: we need a logging.WARN for a -q quiet option |         # FIXME: we need a logging.WARN for a -q quiet option | ||||||
| @ -153,14 +174,6 @@ def setup_logging(log_config=None, log_file=None, verbosity=None): | |||||||
|                 logger.info("Closing log file due to SIGHUP") |                 logger.info("Closing log file due to SIGHUP") | ||||||
|                 handler.doRollover() |                 handler.doRollover() | ||||||
|                 logger.info("Opened new log file due to SIGHUP") |                 logger.info("Opened new log file due to SIGHUP") | ||||||
| 
 |  | ||||||
|             # TODO(paul): obviously this is a terrible mechanism for |  | ||||||
|             #   stealing SIGHUP, because it means no other part of synapse |  | ||||||
|             #   can use it instead. If we want to catch SIGHUP anywhere |  | ||||||
|             #   else as well, I'd suggest we find a nicer way to broadcast |  | ||||||
|             #   it around. |  | ||||||
|             if getattr(signal, "SIGHUP"): |  | ||||||
|                 signal.signal(signal.SIGHUP, sighup) |  | ||||||
|         else: |         else: | ||||||
|             handler = logging.StreamHandler() |             handler = logging.StreamHandler() | ||||||
|         handler.setFormatter(formatter) |         handler.setFormatter(formatter) | ||||||
| @ -169,9 +182,26 @@ def setup_logging(log_config=None, log_file=None, verbosity=None): | |||||||
| 
 | 
 | ||||||
|         logger.addHandler(handler) |         logger.addHandler(handler) | ||||||
|     else: |     else: | ||||||
|  |         def load_log_config(): | ||||||
|             with open(log_config, 'r') as f: |             with open(log_config, 'r') as f: | ||||||
|                 logging.config.dictConfig(yaml.load(f)) |                 logging.config.dictConfig(yaml.load(f)) | ||||||
| 
 | 
 | ||||||
|  |         def sighup(signum, stack): | ||||||
|  |             # it might be better to use a file watcher or something for this. | ||||||
|  |             logging.info("Reloading log config from %s due to SIGHUP", | ||||||
|  |                          log_config) | ||||||
|  |             load_log_config() | ||||||
|  | 
 | ||||||
|  |         load_log_config() | ||||||
|  | 
 | ||||||
|  |     # TODO(paul): obviously this is a terrible mechanism for | ||||||
|  |     #   stealing SIGHUP, because it means no other part of synapse | ||||||
|  |     #   can use it instead. If we want to catch SIGHUP anywhere | ||||||
|  |     #   else as well, I'd suggest we find a nicer way to broadcast | ||||||
|  |     #   it around. | ||||||
|  |     if getattr(signal, "SIGHUP"): | ||||||
|  |         signal.signal(signal.SIGHUP, sighup) | ||||||
|  | 
 | ||||||
|     # It's critical to point twisted's internal logging somewhere, otherwise it |     # It's critical to point twisted's internal logging somewhere, otherwise it | ||||||
|     # stacks up and leaks kup to 64K object; |     # stacks up and leaks kup to 64K object; | ||||||
|     # see: https://twistedmatrix.com/trac/ticket/8164 |     # see: https://twistedmatrix.com/trac/ticket/8164 | ||||||
| @ -183,4 +213,7 @@ def setup_logging(log_config=None, log_file=None, verbosity=None): | |||||||
|     # |     # | ||||||
|     # However this may not be too much of a problem if we are just writing to a file. |     # However this may not be too much of a problem if we are just writing to a file. | ||||||
|     observer = STDLibLogObserver() |     observer = STDLibLogObserver() | ||||||
|     globalLogBeginner.beginLoggingTo([observer]) |     globalLogBeginner.beginLoggingTo( | ||||||
|  |         [observer], | ||||||
|  |         redirectStandardIO=not config.no_redirect_stdio, | ||||||
|  |     ) | ||||||
|  | |||||||
| @ -15,7 +15,6 @@ | |||||||
| 
 | 
 | ||||||
| from synapse.crypto.keyclient import fetch_server_key | from synapse.crypto.keyclient import fetch_server_key | ||||||
| from synapse.api.errors import SynapseError, Codes | from synapse.api.errors import SynapseError, Codes | ||||||
| from synapse.util.retryutils import get_retry_limiter |  | ||||||
| from synapse.util import unwrapFirstError | from synapse.util import unwrapFirstError | ||||||
| from synapse.util.async import ObservableDeferred | from synapse.util.async import ObservableDeferred | ||||||
| from synapse.util.logcontext import ( | from synapse.util.logcontext import ( | ||||||
| @ -96,10 +95,11 @@ class Keyring(object): | |||||||
|         verify_requests = [] |         verify_requests = [] | ||||||
| 
 | 
 | ||||||
|         for server_name, json_object in server_and_json: |         for server_name, json_object in server_and_json: | ||||||
|             logger.debug("Verifying for %s", server_name) |  | ||||||
| 
 | 
 | ||||||
|             key_ids = signature_ids(json_object, server_name) |             key_ids = signature_ids(json_object, server_name) | ||||||
|             if not key_ids: |             if not key_ids: | ||||||
|  |                 logger.warn("Request from %s: no supported signature keys", | ||||||
|  |                             server_name) | ||||||
|                 deferred = defer.fail(SynapseError( |                 deferred = defer.fail(SynapseError( | ||||||
|                     400, |                     400, | ||||||
|                     "Not signed with a supported algorithm", |                     "Not signed with a supported algorithm", | ||||||
| @ -108,6 +108,9 @@ class Keyring(object): | |||||||
|             else: |             else: | ||||||
|                 deferred = defer.Deferred() |                 deferred = defer.Deferred() | ||||||
| 
 | 
 | ||||||
|  |             logger.debug("Verifying for %s with key_ids %s", | ||||||
|  |                          server_name, key_ids) | ||||||
|  | 
 | ||||||
|             verify_request = VerifyKeyRequest( |             verify_request = VerifyKeyRequest( | ||||||
|                 server_name, key_ids, json_object, deferred |                 server_name, key_ids, json_object, deferred | ||||||
|             ) |             ) | ||||||
| @ -142,6 +145,9 @@ class Keyring(object): | |||||||
| 
 | 
 | ||||||
|             json_object = verify_request.json_object |             json_object = verify_request.json_object | ||||||
| 
 | 
 | ||||||
|  |             logger.debug("Got key %s %s:%s for server %s, verifying" % ( | ||||||
|  |                 key_id, verify_key.alg, verify_key.version, server_name, | ||||||
|  |             )) | ||||||
|             try: |             try: | ||||||
|                 verify_signed_json(json_object, server_name, verify_key) |                 verify_signed_json(json_object, server_name, verify_key) | ||||||
|             except: |             except: | ||||||
| @ -231,8 +237,14 @@ class Keyring(object): | |||||||
|             d.addBoth(rm, server_name) |             d.addBoth(rm, server_name) | ||||||
| 
 | 
 | ||||||
|     def get_server_verify_keys(self, verify_requests): |     def get_server_verify_keys(self, verify_requests): | ||||||
|         """Takes a dict of KeyGroups and tries to find at least one key for |         """Tries to find at least one key for each verify request | ||||||
|         each group. | 
 | ||||||
|  |         For each verify_request, verify_request.deferred is called back with | ||||||
|  |         params (server_name, key_id, VerifyKey) if a key is found, or errbacked | ||||||
|  |         with a SynapseError if none of the keys are found. | ||||||
|  | 
 | ||||||
|  |         Args: | ||||||
|  |             verify_requests (list[VerifyKeyRequest]): list of verify requests | ||||||
|         """ |         """ | ||||||
| 
 | 
 | ||||||
|         # These are functions that produce keys given a list of key ids |         # These are functions that produce keys given a list of key ids | ||||||
| @ -245,8 +257,11 @@ class Keyring(object): | |||||||
|         @defer.inlineCallbacks |         @defer.inlineCallbacks | ||||||
|         def do_iterations(): |         def do_iterations(): | ||||||
|             with Measure(self.clock, "get_server_verify_keys"): |             with Measure(self.clock, "get_server_verify_keys"): | ||||||
|  |                 # dict[str, dict[str, VerifyKey]]: results so far. | ||||||
|  |                 # map server_name -> key_id -> VerifyKey | ||||||
|                 merged_results = {} |                 merged_results = {} | ||||||
| 
 | 
 | ||||||
|  |                 # dict[str, set(str)]: keys to fetch for each server | ||||||
|                 missing_keys = {} |                 missing_keys = {} | ||||||
|                 for verify_request in verify_requests: |                 for verify_request in verify_requests: | ||||||
|                     missing_keys.setdefault(verify_request.server_name, set()).update( |                     missing_keys.setdefault(verify_request.server_name, set()).update( | ||||||
| @ -308,6 +323,16 @@ class Keyring(object): | |||||||
| 
 | 
 | ||||||
|     @defer.inlineCallbacks |     @defer.inlineCallbacks | ||||||
|     def get_keys_from_store(self, server_name_and_key_ids): |     def get_keys_from_store(self, server_name_and_key_ids): | ||||||
|  |         """ | ||||||
|  | 
 | ||||||
|  |         Args: | ||||||
|  |             server_name_and_key_ids (list[(str, iterable[str])]): | ||||||
|  |                 list of (server_name, iterable[key_id]) tuples to fetch keys for | ||||||
|  | 
 | ||||||
|  |         Returns: | ||||||
|  |             Deferred: resolves to dict[str, dict[str, VerifyKey]]: map from | ||||||
|  |                 server_name -> key_id -> VerifyKey | ||||||
|  |         """ | ||||||
|         res = yield preserve_context_over_deferred(defer.gatherResults( |         res = yield preserve_context_over_deferred(defer.gatherResults( | ||||||
|             [ |             [ | ||||||
|                 preserve_fn(self.store.get_server_verify_keys)( |                 preserve_fn(self.store.get_server_verify_keys)( | ||||||
| @ -356,12 +381,6 @@ class Keyring(object): | |||||||
|     def get_keys_from_server(self, server_name_and_key_ids): |     def get_keys_from_server(self, server_name_and_key_ids): | ||||||
|         @defer.inlineCallbacks |         @defer.inlineCallbacks | ||||||
|         def get_key(server_name, key_ids): |         def get_key(server_name, key_ids): | ||||||
|             limiter = yield get_retry_limiter( |  | ||||||
|                 server_name, |  | ||||||
|                 self.clock, |  | ||||||
|                 self.store, |  | ||||||
|             ) |  | ||||||
|             with limiter: |  | ||||||
|             keys = None |             keys = None | ||||||
|             try: |             try: | ||||||
|                 keys = yield self.get_server_verify_key_v2_direct( |                 keys = yield self.get_server_verify_key_v2_direct( | ||||||
|  | |||||||
| @ -15,6 +15,32 @@ | |||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class EventContext(object): | class EventContext(object): | ||||||
|  |     """ | ||||||
|  |     Attributes: | ||||||
|  |         current_state_ids (dict[(str, str), str]): | ||||||
|  |             The current state map including the current event. | ||||||
|  |             (type, state_key) -> event_id | ||||||
|  | 
 | ||||||
|  |         prev_state_ids (dict[(str, str), str]): | ||||||
|  |             The current state map excluding the current event. | ||||||
|  |             (type, state_key) -> event_id | ||||||
|  | 
 | ||||||
|  |         state_group (int): state group id | ||||||
|  |         rejected (bool|str): A rejection reason if the event was rejected, else | ||||||
|  |             False | ||||||
|  | 
 | ||||||
|  |         push_actions (list[(str, list[object])]): list of (user_id, actions) | ||||||
|  |             tuples | ||||||
|  | 
 | ||||||
|  |         prev_group (int): Previously persisted state group. ``None`` for an | ||||||
|  |             outlier. | ||||||
|  |         delta_ids (dict[(str, str), str]): Delta from ``prev_group``. | ||||||
|  |             (type, state_key) -> event_id. ``None`` for an outlier. | ||||||
|  | 
 | ||||||
|  |         prev_state_events (?): XXX: is this ever set to anything other than | ||||||
|  |             the empty list? | ||||||
|  |     """ | ||||||
|  | 
 | ||||||
|     __slots__ = [ |     __slots__ = [ | ||||||
|         "current_state_ids", |         "current_state_ids", | ||||||
|         "prev_state_ids", |         "prev_state_ids", | ||||||
|  | |||||||
| @ -29,7 +29,7 @@ from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred | |||||||
| from synapse.events import FrozenEvent, builder | from synapse.events import FrozenEvent, builder | ||||||
| import synapse.metrics | import synapse.metrics | ||||||
| 
 | 
 | ||||||
| from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination | from synapse.util.retryutils import NotRetryingDestination | ||||||
| 
 | 
 | ||||||
| import copy | import copy | ||||||
| import itertools | import itertools | ||||||
| @ -88,7 +88,7 @@ class FederationClient(FederationBase): | |||||||
| 
 | 
 | ||||||
|     @log_function |     @log_function | ||||||
|     def make_query(self, destination, query_type, args, |     def make_query(self, destination, query_type, args, | ||||||
|                    retry_on_dns_fail=False): |                    retry_on_dns_fail=False, ignore_backoff=False): | ||||||
|         """Sends a federation Query to a remote homeserver of the given type |         """Sends a federation Query to a remote homeserver of the given type | ||||||
|         and arguments. |         and arguments. | ||||||
| 
 | 
 | ||||||
| @ -98,6 +98,8 @@ class FederationClient(FederationBase): | |||||||
|                 handler name used in register_query_handler(). |                 handler name used in register_query_handler(). | ||||||
|             args (dict): Mapping of strings to strings containing the details |             args (dict): Mapping of strings to strings containing the details | ||||||
|                 of the query request. |                 of the query request. | ||||||
|  |             ignore_backoff (bool): true to ignore the historical backoff data | ||||||
|  |                 and try the request anyway. | ||||||
| 
 | 
 | ||||||
|         Returns: |         Returns: | ||||||
|             a Deferred which will eventually yield a JSON object from the |             a Deferred which will eventually yield a JSON object from the | ||||||
| @ -106,7 +108,8 @@ class FederationClient(FederationBase): | |||||||
|         sent_queries_counter.inc(query_type) |         sent_queries_counter.inc(query_type) | ||||||
| 
 | 
 | ||||||
|         return self.transport_layer.make_query( |         return self.transport_layer.make_query( | ||||||
|             destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail |             destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail, | ||||||
|  |             ignore_backoff=ignore_backoff, | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|     @log_function |     @log_function | ||||||
| @ -234,13 +237,6 @@ class FederationClient(FederationBase): | |||||||
|                 continue |                 continue | ||||||
| 
 | 
 | ||||||
|             try: |             try: | ||||||
|                 limiter = yield get_retry_limiter( |  | ||||||
|                     destination, |  | ||||||
|                     self._clock, |  | ||||||
|                     self.store, |  | ||||||
|                 ) |  | ||||||
| 
 |  | ||||||
|                 with limiter: |  | ||||||
|                 transaction_data = yield self.transport_layer.get_event( |                 transaction_data = yield self.transport_layer.get_event( | ||||||
|                     destination, event_id, timeout=timeout, |                     destination, event_id, timeout=timeout, | ||||||
|                 ) |                 ) | ||||||
|  | |||||||
| @ -52,7 +52,6 @@ class FederationServer(FederationBase): | |||||||
| 
 | 
 | ||||||
|         self.auth = hs.get_auth() |         self.auth = hs.get_auth() | ||||||
| 
 | 
 | ||||||
|         self._room_pdu_linearizer = Linearizer("fed_room_pdu") |  | ||||||
|         self._server_linearizer = Linearizer("fed_server") |         self._server_linearizer = Linearizer("fed_server") | ||||||
| 
 | 
 | ||||||
|         # We cache responses to state queries, as they take a while and often |         # We cache responses to state queries, as they take a while and often | ||||||
| @ -147,11 +146,15 @@ class FederationServer(FederationBase): | |||||||
|             # check that it's actually being sent from a valid destination to |             # check that it's actually being sent from a valid destination to | ||||||
|             # workaround bug #1753 in 0.18.5 and 0.18.6 |             # workaround bug #1753 in 0.18.5 and 0.18.6 | ||||||
|             if transaction.origin != get_domain_from_id(pdu.event_id): |             if transaction.origin != get_domain_from_id(pdu.event_id): | ||||||
|  |                 # We continue to accept join events from any server; this is | ||||||
|  |                 # necessary for the federation join dance to work correctly. | ||||||
|  |                 # (When we join over federation, the "helper" server is | ||||||
|  |                 # responsible for sending out the join event, rather than the | ||||||
|  |                 # origin. See bug #1893). | ||||||
|                 if not ( |                 if not ( | ||||||
|                     pdu.type == 'm.room.member' and |                     pdu.type == 'm.room.member' and | ||||||
|                     pdu.content and |                     pdu.content and | ||||||
|                     pdu.content.get("membership", None) == 'join' and |                     pdu.content.get("membership", None) == 'join' | ||||||
|                     self.hs.is_mine_id(pdu.state_key) |  | ||||||
|                 ): |                 ): | ||||||
|                     logger.info( |                     logger.info( | ||||||
|                         "Discarding PDU %s from invalid origin %s", |                         "Discarding PDU %s from invalid origin %s", | ||||||
| @ -165,7 +168,7 @@ class FederationServer(FederationBase): | |||||||
|                     ) |                     ) | ||||||
| 
 | 
 | ||||||
|             try: |             try: | ||||||
|                 yield self._handle_new_pdu(transaction.origin, pdu) |                 yield self._handle_received_pdu(transaction.origin, pdu) | ||||||
|                 results.append({}) |                 results.append({}) | ||||||
|             except FederationError as e: |             except FederationError as e: | ||||||
|                 self.send_failure(e, transaction.origin) |                 self.send_failure(e, transaction.origin) | ||||||
| @ -497,27 +500,16 @@ class FederationServer(FederationBase): | |||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|     @defer.inlineCallbacks |     @defer.inlineCallbacks | ||||||
|     @log_function |     def _handle_received_pdu(self, origin, pdu): | ||||||
|     def _handle_new_pdu(self, origin, pdu, get_missing=True): |         """ Process a PDU received in a federation /send/ transaction. | ||||||
| 
 | 
 | ||||||
|         # We reprocess pdus when we have seen them only as outliers |         Args: | ||||||
|         existing = yield self._get_persisted_pdu( |             origin (str): server which sent the pdu | ||||||
|             origin, pdu.event_id, do_auth=False |             pdu (FrozenEvent): received pdu | ||||||
|         ) |  | ||||||
| 
 |  | ||||||
|         # FIXME: Currently we fetch an event again when we already have it |  | ||||||
|         # if it has been marked as an outlier. |  | ||||||
| 
 |  | ||||||
|         already_seen = ( |  | ||||||
|             existing and ( |  | ||||||
|                 not existing.internal_metadata.is_outlier() |  | ||||||
|                 or pdu.internal_metadata.is_outlier() |  | ||||||
|             ) |  | ||||||
|         ) |  | ||||||
|         if already_seen: |  | ||||||
|             logger.debug("Already seen pdu %s", pdu.event_id) |  | ||||||
|             return |  | ||||||
| 
 | 
 | ||||||
|  |         Returns (Deferred): completes with None | ||||||
|  |         Raises: FederationError if the signatures / hash do not match | ||||||
|  |     """ | ||||||
|         # Check signature. |         # Check signature. | ||||||
|         try: |         try: | ||||||
|             pdu = yield self._check_sigs_and_hash(pdu) |             pdu = yield self._check_sigs_and_hash(pdu) | ||||||
| @ -529,143 +521,7 @@ class FederationServer(FederationBase): | |||||||
|                 affected=pdu.event_id, |                 affected=pdu.event_id, | ||||||
|             ) |             ) | ||||||
| 
 | 
 | ||||||
|         state = None |         yield self.handler.on_receive_pdu(origin, pdu, get_missing=True) | ||||||
| 
 |  | ||||||
|         auth_chain = [] |  | ||||||
| 
 |  | ||||||
|         have_seen = yield self.store.have_events( |  | ||||||
|             [ev for ev, _ in pdu.prev_events] |  | ||||||
|         ) |  | ||||||
| 
 |  | ||||||
|         fetch_state = False |  | ||||||
| 
 |  | ||||||
|         # Get missing pdus if necessary. |  | ||||||
|         if not pdu.internal_metadata.is_outlier(): |  | ||||||
|             # We only backfill backwards to the min depth. |  | ||||||
|             min_depth = yield self.handler.get_min_depth_for_context( |  | ||||||
|                 pdu.room_id |  | ||||||
|             ) |  | ||||||
| 
 |  | ||||||
|             logger.debug( |  | ||||||
|                 "_handle_new_pdu min_depth for %s: %d", |  | ||||||
|                 pdu.room_id, min_depth |  | ||||||
|             ) |  | ||||||
| 
 |  | ||||||
|             prevs = {e_id for e_id, _ in pdu.prev_events} |  | ||||||
|             seen = set(have_seen.keys()) |  | ||||||
| 
 |  | ||||||
|             if min_depth and pdu.depth < min_depth: |  | ||||||
|                 # This is so that we don't notify the user about this |  | ||||||
|                 # message, to work around the fact that some events will |  | ||||||
|                 # reference really really old events we really don't want to |  | ||||||
|                 # send to the clients. |  | ||||||
|                 pdu.internal_metadata.outlier = True |  | ||||||
|             elif min_depth and pdu.depth > min_depth: |  | ||||||
|                 if get_missing and prevs - seen: |  | ||||||
|                     # If we're missing stuff, ensure we only fetch stuff one |  | ||||||
|                     # at a time. |  | ||||||
|                     logger.info( |  | ||||||
|                         "Acquiring lock for room %r to fetch %d missing events: %r...", |  | ||||||
|                         pdu.room_id, len(prevs - seen), list(prevs - seen)[:5], |  | ||||||
|                     ) |  | ||||||
|                     with (yield self._room_pdu_linearizer.queue(pdu.room_id)): |  | ||||||
|                         logger.info( |  | ||||||
|                             "Acquired lock for room %r to fetch %d missing events", |  | ||||||
|                             pdu.room_id, len(prevs - seen), |  | ||||||
|                         ) |  | ||||||
| 
 |  | ||||||
|                         # We recalculate seen, since it may have changed. |  | ||||||
|                         have_seen = yield self.store.have_events(prevs) |  | ||||||
|                         seen = set(have_seen.keys()) |  | ||||||
| 
 |  | ||||||
|                         if prevs - seen: |  | ||||||
|                             latest = yield self.store.get_latest_event_ids_in_room( |  | ||||||
|                                 pdu.room_id |  | ||||||
|                             ) |  | ||||||
| 
 |  | ||||||
|                             # We add the prev events that we have seen to the latest |  | ||||||
|                             # list to ensure the remote server doesn't give them to us |  | ||||||
|                             latest = set(latest) |  | ||||||
|                             latest |= seen |  | ||||||
| 
 |  | ||||||
|                             logger.info( |  | ||||||
|                                 "Missing %d events for room %r: %r...", |  | ||||||
|                                 len(prevs - seen), pdu.room_id, list(prevs - seen)[:5] |  | ||||||
|                             ) |  | ||||||
| 
 |  | ||||||
|                             # XXX: we set timeout to 10s to help workaround |  | ||||||
|                             # https://github.com/matrix-org/synapse/issues/1733. |  | ||||||
|                             # The reason is to avoid holding the linearizer lock |  | ||||||
|                             # whilst processing inbound /send transactions, causing |  | ||||||
|                             # FDs to stack up and block other inbound transactions |  | ||||||
|                             # which empirically can currently take up to 30 minutes. |  | ||||||
|                             # |  | ||||||
|                             # N.B. this explicitly disables retry attempts. |  | ||||||
|                             # |  | ||||||
|                             # N.B. this also increases our chances of falling back to |  | ||||||
|                             # fetching fresh state for the room if the missing event |  | ||||||
|                             # can't be found, which slightly reduces our security. |  | ||||||
|                             # it may also increase our DAG extremity count for the room, |  | ||||||
|                             # causing additional state resolution?  See #1760. |  | ||||||
|                             # However, fetching state doesn't hold the linearizer lock |  | ||||||
|                             # apparently. |  | ||||||
|                             # |  | ||||||
|                             # see https://github.com/matrix-org/synapse/pull/1744 |  | ||||||
| 
 |  | ||||||
|                             missing_events = yield self.get_missing_events( |  | ||||||
|                                 origin, |  | ||||||
|                                 pdu.room_id, |  | ||||||
|                                 earliest_events_ids=list(latest), |  | ||||||
|                                 latest_events=[pdu], |  | ||||||
|                                 limit=10, |  | ||||||
|                                 min_depth=min_depth, |  | ||||||
|                                 timeout=10000, |  | ||||||
|                             ) |  | ||||||
| 
 |  | ||||||
|                             # We want to sort these by depth so we process them and |  | ||||||
|                             # tell clients about them in order. |  | ||||||
|                             missing_events.sort(key=lambda x: x.depth) |  | ||||||
| 
 |  | ||||||
|                             for e in missing_events: |  | ||||||
|                                 yield self._handle_new_pdu( |  | ||||||
|                                     origin, |  | ||||||
|                                     e, |  | ||||||
|                                     get_missing=False |  | ||||||
|                                 ) |  | ||||||
| 
 |  | ||||||
|                             have_seen = yield self.store.have_events( |  | ||||||
|                                 [ev for ev, _ in pdu.prev_events] |  | ||||||
|                             ) |  | ||||||
| 
 |  | ||||||
|             prevs = {e_id for e_id, _ in pdu.prev_events} |  | ||||||
|             seen = set(have_seen.keys()) |  | ||||||
|             if prevs - seen: |  | ||||||
|                 logger.info( |  | ||||||
|                     "Still missing %d events for room %r: %r...", |  | ||||||
|                     len(prevs - seen), pdu.room_id, list(prevs - seen)[:5] |  | ||||||
|                 ) |  | ||||||
|                 fetch_state = True |  | ||||||
| 
 |  | ||||||
|         if fetch_state: |  | ||||||
|             # We need to get the state at this event, since we haven't |  | ||||||
|             # processed all the prev events. |  | ||||||
|             logger.debug( |  | ||||||
|                 "_handle_new_pdu getting state for %s", |  | ||||||
|                 pdu.room_id |  | ||||||
|             ) |  | ||||||
|             try: |  | ||||||
|                 state, auth_chain = yield self.get_state_for_room( |  | ||||||
|                     origin, pdu.room_id, pdu.event_id, |  | ||||||
|                 ) |  | ||||||
|             except: |  | ||||||
|                 logger.exception("Failed to get state for event: %s", pdu.event_id) |  | ||||||
| 
 |  | ||||||
|         yield self.handler.on_receive_pdu( |  | ||||||
|             origin, |  | ||||||
|             pdu, |  | ||||||
|             state=state, |  | ||||||
|             auth_chain=auth_chain, |  | ||||||
|         ) |  | ||||||
| 
 | 
 | ||||||
|     def __str__(self): |     def __str__(self): | ||||||
|         return "<ReplicationLayer(%s)>" % self.server_name |         return "<ReplicationLayer(%s)>" % self.server_name | ||||||
|  | |||||||
| @ -54,6 +54,7 @@ class FederationRemoteSendQueue(object): | |||||||
|     def __init__(self, hs): |     def __init__(self, hs): | ||||||
|         self.server_name = hs.hostname |         self.server_name = hs.hostname | ||||||
|         self.clock = hs.get_clock() |         self.clock = hs.get_clock() | ||||||
|  |         self.notifier = hs.get_notifier() | ||||||
| 
 | 
 | ||||||
|         self.presence_map = {} |         self.presence_map = {} | ||||||
|         self.presence_changed = sorteddict() |         self.presence_changed = sorteddict() | ||||||
| @ -186,6 +187,8 @@ class FederationRemoteSendQueue(object): | |||||||
|         else: |         else: | ||||||
|             self.edus[pos] = edu |             self.edus[pos] = edu | ||||||
| 
 | 
 | ||||||
|  |         self.notifier.on_new_replication_data() | ||||||
|  | 
 | ||||||
|     def send_presence(self, destination, states): |     def send_presence(self, destination, states): | ||||||
|         """As per TransactionQueue""" |         """As per TransactionQueue""" | ||||||
|         pos = self._next_pos() |         pos = self._next_pos() | ||||||
| @ -199,16 +202,20 @@ class FederationRemoteSendQueue(object): | |||||||
|             (destination, state.user_id) for state in states |             (destination, state.user_id) for state in states | ||||||
|         ] |         ] | ||||||
| 
 | 
 | ||||||
|  |         self.notifier.on_new_replication_data() | ||||||
|  | 
 | ||||||
|     def send_failure(self, failure, destination): |     def send_failure(self, failure, destination): | ||||||
|         """As per TransactionQueue""" |         """As per TransactionQueue""" | ||||||
|         pos = self._next_pos() |         pos = self._next_pos() | ||||||
| 
 | 
 | ||||||
|         self.failures[pos] = (destination, str(failure)) |         self.failures[pos] = (destination, str(failure)) | ||||||
|  |         self.notifier.on_new_replication_data() | ||||||
| 
 | 
 | ||||||
|     def send_device_messages(self, destination): |     def send_device_messages(self, destination): | ||||||
|         """As per TransactionQueue""" |         """As per TransactionQueue""" | ||||||
|         pos = self._next_pos() |         pos = self._next_pos() | ||||||
|         self.device_messages[pos] = destination |         self.device_messages[pos] = destination | ||||||
|  |         self.notifier.on_new_replication_data() | ||||||
| 
 | 
 | ||||||
|     def get_current_token(self): |     def get_current_token(self): | ||||||
|         return self.pos - 1 |         return self.pos - 1 | ||||||
|  | |||||||
| @ -12,7 +12,7 @@ | |||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
| # See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||||
| # limitations under the License. | # limitations under the License. | ||||||
| 
 | import datetime | ||||||
| 
 | 
 | ||||||
| from twisted.internet import defer | from twisted.internet import defer | ||||||
| 
 | 
 | ||||||
| @ -22,9 +22,7 @@ from .units import Transaction, Edu | |||||||
| from synapse.api.errors import HttpResponseException | from synapse.api.errors import HttpResponseException | ||||||
| from synapse.util.async import run_on_reactor | from synapse.util.async import run_on_reactor | ||||||
| from synapse.util.logcontext import preserve_context_over_fn | from synapse.util.logcontext import preserve_context_over_fn | ||||||
| from synapse.util.retryutils import ( | from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter | ||||||
|     get_retry_limiter, NotRetryingDestination, |  | ||||||
| ) |  | ||||||
| from synapse.util.metrics import measure_func | from synapse.util.metrics import measure_func | ||||||
| from synapse.types import get_domain_from_id | from synapse.types import get_domain_from_id | ||||||
| from synapse.handlers.presence import format_user_presence_state | from synapse.handlers.presence import format_user_presence_state | ||||||
| @ -99,7 +97,12 @@ class TransactionQueue(object): | |||||||
|         # destination -> list of tuple(failure, deferred) |         # destination -> list of tuple(failure, deferred) | ||||||
|         self.pending_failures_by_dest = {} |         self.pending_failures_by_dest = {} | ||||||
| 
 | 
 | ||||||
|  |         # destination -> stream_id of last successfully sent to-device message. | ||||||
|  |         # NB: may be a long or an int. | ||||||
|         self.last_device_stream_id_by_dest = {} |         self.last_device_stream_id_by_dest = {} | ||||||
|  | 
 | ||||||
|  |         # destination -> stream_id of last successfully sent device list | ||||||
|  |         # update. | ||||||
|         self.last_device_list_stream_id_by_dest = {} |         self.last_device_list_stream_id_by_dest = {} | ||||||
| 
 | 
 | ||||||
|         # HACK to get unique tx id |         # HACK to get unique tx id | ||||||
| @ -300,20 +303,20 @@ class TransactionQueue(object): | |||||||
|             ) |             ) | ||||||
|             return |             return | ||||||
| 
 | 
 | ||||||
|  |         pending_pdus = [] | ||||||
|         try: |         try: | ||||||
|             self.pending_transactions[destination] = 1 |             self.pending_transactions[destination] = 1 | ||||||
| 
 | 
 | ||||||
|  |             # This will throw if we wouldn't retry. We do this here so we fail | ||||||
|  |             # quickly, but we will later check this again in the http client, | ||||||
|  |             # hence why we throw the result away. | ||||||
|  |             yield get_retry_limiter(destination, self.clock, self.store) | ||||||
|  | 
 | ||||||
|             # XXX: what's this for? |             # XXX: what's this for? | ||||||
|             yield run_on_reactor() |             yield run_on_reactor() | ||||||
| 
 | 
 | ||||||
|  |             pending_pdus = [] | ||||||
|             while True: |             while True: | ||||||
|                 limiter = yield get_retry_limiter( |  | ||||||
|                     destination, |  | ||||||
|                     self.clock, |  | ||||||
|                     self.store, |  | ||||||
|                     backoff_on_404=True,  # If we get a 404 the other side has gone |  | ||||||
|                 ) |  | ||||||
| 
 |  | ||||||
|                 device_message_edus, device_stream_id, dev_list_id = ( |                 device_message_edus, device_stream_id, dev_list_id = ( | ||||||
|                     yield self._get_new_device_messages(destination) |                     yield self._get_new_device_messages(destination) | ||||||
|                 ) |                 ) | ||||||
| @ -369,7 +372,6 @@ class TransactionQueue(object): | |||||||
| 
 | 
 | ||||||
|                 success = yield self._send_new_transaction( |                 success = yield self._send_new_transaction( | ||||||
|                     destination, pending_pdus, pending_edus, pending_failures, |                     destination, pending_pdus, pending_edus, pending_failures, | ||||||
|                     limiter=limiter, |  | ||||||
|                 ) |                 ) | ||||||
|                 if success: |                 if success: | ||||||
|                     # Remove the acknowledged device messages from the database |                     # Remove the acknowledged device messages from the database | ||||||
| @ -387,12 +389,24 @@ class TransactionQueue(object): | |||||||
|                     self.last_device_list_stream_id_by_dest[destination] = dev_list_id |                     self.last_device_list_stream_id_by_dest[destination] = dev_list_id | ||||||
|                 else: |                 else: | ||||||
|                     break |                     break | ||||||
|         except NotRetryingDestination: |         except NotRetryingDestination as e: | ||||||
|             logger.debug( |             logger.debug( | ||||||
|                 "TX [%s] not ready for retry yet - " |                 "TX [%s] not ready for retry yet (next retry at %s) - " | ||||||
|                 "dropping transaction for now", |                 "dropping transaction for now", | ||||||
|                 destination, |                 destination, | ||||||
|  |                 datetime.datetime.fromtimestamp( | ||||||
|  |                     (e.retry_last_ts + e.retry_interval) / 1000.0 | ||||||
|  |                 ), | ||||||
|             ) |             ) | ||||||
|  |         except Exception as e: | ||||||
|  |             logger.warn( | ||||||
|  |                 "TX [%s] Failed to send transaction: %s", | ||||||
|  |                 destination, | ||||||
|  |                 e, | ||||||
|  |             ) | ||||||
|  |             for p, _ in pending_pdus: | ||||||
|  |                 logger.info("Failed to send event %s to %s", p.event_id, | ||||||
|  |                             destination) | ||||||
|         finally: |         finally: | ||||||
|             # We want to be *very* sure we delete this after we stop processing |             # We want to be *very* sure we delete this after we stop processing | ||||||
|             self.pending_transactions.pop(destination, None) |             self.pending_transactions.pop(destination, None) | ||||||
| @ -432,7 +446,7 @@ class TransactionQueue(object): | |||||||
|     @measure_func("_send_new_transaction") |     @measure_func("_send_new_transaction") | ||||||
|     @defer.inlineCallbacks |     @defer.inlineCallbacks | ||||||
|     def _send_new_transaction(self, destination, pending_pdus, pending_edus, |     def _send_new_transaction(self, destination, pending_pdus, pending_edus, | ||||||
|                               pending_failures, limiter): |                               pending_failures): | ||||||
| 
 | 
 | ||||||
|         # Sort based on the order field |         # Sort based on the order field | ||||||
|         pending_pdus.sort(key=lambda t: t[1]) |         pending_pdus.sort(key=lambda t: t[1]) | ||||||
| @ -442,7 +456,6 @@ class TransactionQueue(object): | |||||||
| 
 | 
 | ||||||
|         success = True |         success = True | ||||||
| 
 | 
 | ||||||
|         try: |  | ||||||
|         logger.debug("TX [%s] _attempt_new_transaction", destination) |         logger.debug("TX [%s] _attempt_new_transaction", destination) | ||||||
| 
 | 
 | ||||||
|         txn_id = str(self._next_txn_id) |         txn_id = str(self._next_txn_id) | ||||||
| @ -483,7 +496,6 @@ class TransactionQueue(object): | |||||||
|             len(failures), |             len(failures), | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|             with limiter: |  | ||||||
|         # Actually send the transaction |         # Actually send the transaction | ||||||
| 
 | 
 | ||||||
|         # FIXME (erikj): This is a bit of a hack to make the Pdu age |         # FIXME (erikj): This is a bit of a hack to make the Pdu age | ||||||
| @ -543,31 +555,5 @@ class TransactionQueue(object): | |||||||
|                     "Failed to send event %s to %s", p.event_id, destination |                     "Failed to send event %s to %s", p.event_id, destination | ||||||
|                 ) |                 ) | ||||||
|             success = False |             success = False | ||||||
|         except RuntimeError as e: |  | ||||||
|             # We capture this here as there as nothing actually listens |  | ||||||
|             # for this finishing functions deferred. |  | ||||||
|             logger.warn( |  | ||||||
|                 "TX [%s] Problem in _attempt_transaction: %s", |  | ||||||
|                 destination, |  | ||||||
|                 e, |  | ||||||
|             ) |  | ||||||
| 
 |  | ||||||
|             success = False |  | ||||||
| 
 |  | ||||||
|             for p in pdus: |  | ||||||
|                 logger.info("Failed to send event %s to %s", p.event_id, destination) |  | ||||||
|         except Exception as e: |  | ||||||
|             # We capture this here as there as nothing actually listens |  | ||||||
|             # for this finishing functions deferred. |  | ||||||
|             logger.warn( |  | ||||||
|                 "TX [%s] Problem in _attempt_transaction: %s", |  | ||||||
|                 destination, |  | ||||||
|                 e, |  | ||||||
|             ) |  | ||||||
| 
 |  | ||||||
|             success = False |  | ||||||
| 
 |  | ||||||
|             for p in pdus: |  | ||||||
|                 logger.info("Failed to send event %s to %s", p.event_id, destination) |  | ||||||
| 
 | 
 | ||||||
|         defer.returnValue(success) |         defer.returnValue(success) | ||||||
|  | |||||||
| @ -163,6 +163,7 @@ class TransportLayerClient(object): | |||||||
|             data=json_data, |             data=json_data, | ||||||
|             json_data_callback=json_data_callback, |             json_data_callback=json_data_callback, | ||||||
|             long_retries=True, |             long_retries=True, | ||||||
|  |             backoff_on_404=True,  # If we get a 404 the other side has gone | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|         logger.debug( |         logger.debug( | ||||||
| @ -174,7 +175,8 @@ class TransportLayerClient(object): | |||||||
| 
 | 
 | ||||||
|     @defer.inlineCallbacks |     @defer.inlineCallbacks | ||||||
|     @log_function |     @log_function | ||||||
|     def make_query(self, destination, query_type, args, retry_on_dns_fail): |     def make_query(self, destination, query_type, args, retry_on_dns_fail, | ||||||
|  |                    ignore_backoff=False): | ||||||
|         path = PREFIX + "/query/%s" % query_type |         path = PREFIX + "/query/%s" % query_type | ||||||
| 
 | 
 | ||||||
|         content = yield self.client.get_json( |         content = yield self.client.get_json( | ||||||
| @ -183,6 +185,7 @@ class TransportLayerClient(object): | |||||||
|             args=args, |             args=args, | ||||||
|             retry_on_dns_fail=retry_on_dns_fail, |             retry_on_dns_fail=retry_on_dns_fail, | ||||||
|             timeout=10000, |             timeout=10000, | ||||||
|  |             ignore_backoff=ignore_backoff, | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|         defer.returnValue(content) |         defer.returnValue(content) | ||||||
| @ -242,6 +245,7 @@ class TransportLayerClient(object): | |||||||
|             destination=destination, |             destination=destination, | ||||||
|             path=path, |             path=path, | ||||||
|             data=content, |             data=content, | ||||||
|  |             ignore_backoff=True, | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|         defer.returnValue(response) |         defer.returnValue(response) | ||||||
| @ -269,6 +273,7 @@ class TransportLayerClient(object): | |||||||
|             destination=remote_server, |             destination=remote_server, | ||||||
|             path=path, |             path=path, | ||||||
|             args=args, |             args=args, | ||||||
|  |             ignore_backoff=True, | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|         defer.returnValue(response) |         defer.returnValue(response) | ||||||
|  | |||||||
| @ -1,5 +1,6 @@ | |||||||
| # -*- coding: utf-8 -*- | # -*- coding: utf-8 -*- | ||||||
| # Copyright 2014 - 2016 OpenMarket Ltd | # Copyright 2014 - 2016 OpenMarket Ltd | ||||||
|  | # Copyright 2017 Vector Creations Ltd | ||||||
| # | # | ||||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||||
| @ -47,6 +48,7 @@ class AuthHandler(BaseHandler): | |||||||
|             LoginType.PASSWORD: self._check_password_auth, |             LoginType.PASSWORD: self._check_password_auth, | ||||||
|             LoginType.RECAPTCHA: self._check_recaptcha, |             LoginType.RECAPTCHA: self._check_recaptcha, | ||||||
|             LoginType.EMAIL_IDENTITY: self._check_email_identity, |             LoginType.EMAIL_IDENTITY: self._check_email_identity, | ||||||
|  |             LoginType.MSISDN: self._check_msisdn, | ||||||
|             LoginType.DUMMY: self._check_dummy_auth, |             LoginType.DUMMY: self._check_dummy_auth, | ||||||
|         } |         } | ||||||
|         self.bcrypt_rounds = hs.config.bcrypt_rounds |         self.bcrypt_rounds = hs.config.bcrypt_rounds | ||||||
| @ -307,31 +309,47 @@ class AuthHandler(BaseHandler): | |||||||
|                 defer.returnValue(True) |                 defer.returnValue(True) | ||||||
|         raise LoginError(401, "", errcode=Codes.UNAUTHORIZED) |         raise LoginError(401, "", errcode=Codes.UNAUTHORIZED) | ||||||
| 
 | 
 | ||||||
|     @defer.inlineCallbacks |  | ||||||
|     def _check_email_identity(self, authdict, _): |     def _check_email_identity(self, authdict, _): | ||||||
|  |         return self._check_threepid('email', authdict) | ||||||
|  | 
 | ||||||
|  |     def _check_msisdn(self, authdict, _): | ||||||
|  |         return self._check_threepid('msisdn', authdict) | ||||||
|  | 
 | ||||||
|  |     @defer.inlineCallbacks | ||||||
|  |     def _check_dummy_auth(self, authdict, _): | ||||||
|  |         yield run_on_reactor() | ||||||
|  |         defer.returnValue(True) | ||||||
|  | 
 | ||||||
|  |     @defer.inlineCallbacks | ||||||
|  |     def _check_threepid(self, medium, authdict): | ||||||
|         yield run_on_reactor() |         yield run_on_reactor() | ||||||
| 
 | 
 | ||||||
|         if 'threepid_creds' not in authdict: |         if 'threepid_creds' not in authdict: | ||||||
|             raise LoginError(400, "Missing threepid_creds", Codes.MISSING_PARAM) |             raise LoginError(400, "Missing threepid_creds", Codes.MISSING_PARAM) | ||||||
| 
 | 
 | ||||||
|         threepid_creds = authdict['threepid_creds'] |         threepid_creds = authdict['threepid_creds'] | ||||||
|  | 
 | ||||||
|         identity_handler = self.hs.get_handlers().identity_handler |         identity_handler = self.hs.get_handlers().identity_handler | ||||||
| 
 | 
 | ||||||
|         logger.info("Getting validated threepid. threepidcreds: %r" % (threepid_creds,)) |         logger.info("Getting validated threepid. threepidcreds: %r", (threepid_creds,)) | ||||||
|         threepid = yield identity_handler.threepid_from_creds(threepid_creds) |         threepid = yield identity_handler.threepid_from_creds(threepid_creds) | ||||||
| 
 | 
 | ||||||
|         if not threepid: |         if not threepid: | ||||||
|             raise LoginError(401, "", errcode=Codes.UNAUTHORIZED) |             raise LoginError(401, "", errcode=Codes.UNAUTHORIZED) | ||||||
| 
 | 
 | ||||||
|  |         if threepid['medium'] != medium: | ||||||
|  |             raise LoginError( | ||||||
|  |                 401, | ||||||
|  |                 "Expecting threepid of type '%s', got '%s'" % ( | ||||||
|  |                     medium, threepid['medium'], | ||||||
|  |                 ), | ||||||
|  |                 errcode=Codes.UNAUTHORIZED | ||||||
|  |             ) | ||||||
|  | 
 | ||||||
|         threepid['threepid_creds'] = authdict['threepid_creds'] |         threepid['threepid_creds'] = authdict['threepid_creds'] | ||||||
| 
 | 
 | ||||||
|         defer.returnValue(threepid) |         defer.returnValue(threepid) | ||||||
| 
 | 
 | ||||||
|     @defer.inlineCallbacks |  | ||||||
|     def _check_dummy_auth(self, authdict, _): |  | ||||||
|         yield run_on_reactor() |  | ||||||
|         defer.returnValue(True) |  | ||||||
| 
 |  | ||||||
|     def _get_params_recaptcha(self): |     def _get_params_recaptcha(self): | ||||||
|         return {"public_key": self.hs.config.recaptcha_public_key} |         return {"public_key": self.hs.config.recaptcha_public_key} | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -169,6 +169,40 @@ class DeviceHandler(BaseHandler): | |||||||
| 
 | 
 | ||||||
|         yield self.notify_device_update(user_id, [device_id]) |         yield self.notify_device_update(user_id, [device_id]) | ||||||
| 
 | 
 | ||||||
|  |     @defer.inlineCallbacks | ||||||
|  |     def delete_devices(self, user_id, device_ids): | ||||||
|  |         """ Delete several devices | ||||||
|  | 
 | ||||||
|  |         Args: | ||||||
|  |             user_id (str): | ||||||
|  |             device_ids (str): The list of device IDs to delete | ||||||
|  | 
 | ||||||
|  |         Returns: | ||||||
|  |             defer.Deferred: | ||||||
|  |         """ | ||||||
|  | 
 | ||||||
|  |         try: | ||||||
|  |             yield self.store.delete_devices(user_id, device_ids) | ||||||
|  |         except errors.StoreError, e: | ||||||
|  |             if e.code == 404: | ||||||
|  |                 # no match | ||||||
|  |                 pass | ||||||
|  |             else: | ||||||
|  |                 raise | ||||||
|  | 
 | ||||||
|  |         # Delete access tokens and e2e keys for each device. Not optimised as it is not | ||||||
|  |         # considered as part of a critical path. | ||||||
|  |         for device_id in device_ids: | ||||||
|  |             yield self.store.user_delete_access_tokens( | ||||||
|  |                 user_id, device_id=device_id, | ||||||
|  |                 delete_refresh_tokens=True, | ||||||
|  |             ) | ||||||
|  |             yield self.store.delete_e2e_keys_by_device( | ||||||
|  |                 user_id=user_id, device_id=device_id | ||||||
|  |             ) | ||||||
|  | 
 | ||||||
|  |         yield self.notify_device_update(user_id, device_ids) | ||||||
|  | 
 | ||||||
|     @defer.inlineCallbacks |     @defer.inlineCallbacks | ||||||
|     def update_device(self, user_id, device_id, content): |     def update_device(self, user_id, device_id, content): | ||||||
|         """ Update the given device |         """ Update the given device | ||||||
| @ -214,8 +248,7 @@ class DeviceHandler(BaseHandler): | |||||||
|             user_id, device_ids, list(hosts) |             user_id, device_ids, list(hosts) | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|         rooms = yield self.store.get_rooms_for_user(user_id) |         room_ids = yield self.store.get_rooms_for_user(user_id) | ||||||
|         room_ids = [r.room_id for r in rooms] |  | ||||||
| 
 | 
 | ||||||
|         yield self.notifier.on_new_event( |         yield self.notifier.on_new_event( | ||||||
|             "device_list_key", position, rooms=room_ids, |             "device_list_key", position, rooms=room_ids, | ||||||
| @ -236,8 +269,7 @@ class DeviceHandler(BaseHandler): | |||||||
|             user_id (str) |             user_id (str) | ||||||
|             from_token (StreamToken) |             from_token (StreamToken) | ||||||
|         """ |         """ | ||||||
|         rooms = yield self.store.get_rooms_for_user(user_id) |         room_ids = yield self.store.get_rooms_for_user(user_id) | ||||||
|         room_ids = set(r.room_id for r in rooms) |  | ||||||
| 
 | 
 | ||||||
|         # First we check if any devices have changed |         # First we check if any devices have changed | ||||||
|         changed = yield self.store.get_user_whose_devices_changed( |         changed = yield self.store.get_user_whose_devices_changed( | ||||||
| @ -262,7 +294,7 @@ class DeviceHandler(BaseHandler): | |||||||
|                 # ordering: treat it the same as a new room |                 # ordering: treat it the same as a new room | ||||||
|                 event_ids = [] |                 event_ids = [] | ||||||
| 
 | 
 | ||||||
|             current_state_ids = yield self.state.get_current_state_ids(room_id) |             current_state_ids = yield self.store.get_current_state_ids(room_id) | ||||||
| 
 | 
 | ||||||
|             # special-case for an empty prev state: include all members |             # special-case for an empty prev state: include all members | ||||||
|             # in the changed list |             # in the changed list | ||||||
| @ -313,8 +345,8 @@ class DeviceHandler(BaseHandler): | |||||||
|     @defer.inlineCallbacks |     @defer.inlineCallbacks | ||||||
|     def user_left_room(self, user, room_id): |     def user_left_room(self, user, room_id): | ||||||
|         user_id = user.to_string() |         user_id = user.to_string() | ||||||
|         rooms = yield self.store.get_rooms_for_user(user_id) |         room_ids = yield self.store.get_rooms_for_user(user_id) | ||||||
|         if not rooms: |         if not room_ids: | ||||||
|             # We no longer share rooms with this user, so we'll no longer |             # We no longer share rooms with this user, so we'll no longer | ||||||
|             # receive device updates. Mark this in DB. |             # receive device updates. Mark this in DB. | ||||||
|             yield self.store.mark_remote_user_device_list_as_unsubscribed(user_id) |             yield self.store.mark_remote_user_device_list_as_unsubscribed(user_id) | ||||||
| @ -370,8 +402,8 @@ class DeviceListEduUpdater(object): | |||||||
|             logger.warning("Got device list update edu for %r from %r", user_id, origin) |             logger.warning("Got device list update edu for %r from %r", user_id, origin) | ||||||
|             return |             return | ||||||
| 
 | 
 | ||||||
|         rooms = yield self.store.get_rooms_for_user(user_id) |         room_ids = yield self.store.get_rooms_for_user(user_id) | ||||||
|         if not rooms: |         if not room_ids: | ||||||
|             # We don't share any rooms with this user. Ignore update, as we |             # We don't share any rooms with this user. Ignore update, as we | ||||||
|             # probably won't get any further updates. |             # probably won't get any further updates. | ||||||
|             return |             return | ||||||
|  | |||||||
| @ -175,6 +175,7 @@ class DirectoryHandler(BaseHandler): | |||||||
|                         "room_alias": room_alias.to_string(), |                         "room_alias": room_alias.to_string(), | ||||||
|                     }, |                     }, | ||||||
|                     retry_on_dns_fail=False, |                     retry_on_dns_fail=False, | ||||||
|  |                     ignore_backoff=True, | ||||||
|                 ) |                 ) | ||||||
|             except CodeMessageException as e: |             except CodeMessageException as e: | ||||||
|                 logging.warn("Error retrieving alias") |                 logging.warn("Error retrieving alias") | ||||||
|  | |||||||
| @ -22,7 +22,7 @@ from twisted.internet import defer | |||||||
| from synapse.api.errors import SynapseError, CodeMessageException | from synapse.api.errors import SynapseError, CodeMessageException | ||||||
| from synapse.types import get_domain_from_id | from synapse.types import get_domain_from_id | ||||||
| from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred | from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred | ||||||
| from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination | from synapse.util.retryutils import NotRetryingDestination | ||||||
| 
 | 
 | ||||||
| logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||||
| 
 | 
 | ||||||
| @ -121,10 +121,6 @@ class E2eKeysHandler(object): | |||||||
|         def do_remote_query(destination): |         def do_remote_query(destination): | ||||||
|             destination_query = remote_queries_not_in_cache[destination] |             destination_query = remote_queries_not_in_cache[destination] | ||||||
|             try: |             try: | ||||||
|                 limiter = yield get_retry_limiter( |  | ||||||
|                     destination, self.clock, self.store |  | ||||||
|                 ) |  | ||||||
|                 with limiter: |  | ||||||
|                 remote_result = yield self.federation.query_client_keys( |                 remote_result = yield self.federation.query_client_keys( | ||||||
|                     destination, |                     destination, | ||||||
|                     {"device_keys": destination_query}, |                     {"device_keys": destination_query}, | ||||||
| @ -239,10 +235,6 @@ class E2eKeysHandler(object): | |||||||
|         def claim_client_keys(destination): |         def claim_client_keys(destination): | ||||||
|             device_keys = remote_queries[destination] |             device_keys = remote_queries[destination] | ||||||
|             try: |             try: | ||||||
|                 limiter = yield get_retry_limiter( |  | ||||||
|                     destination, self.clock, self.store |  | ||||||
|                 ) |  | ||||||
|                 with limiter: |  | ||||||
|                 remote_result = yield self.federation.claim_client_keys( |                 remote_result = yield self.federation.claim_client_keys( | ||||||
|                     destination, |                     destination, | ||||||
|                     {"one_time_keys": device_keys}, |                     {"one_time_keys": device_keys}, | ||||||
| @ -316,7 +308,7 @@ class E2eKeysHandler(object): | |||||||
|         # old access_token without an associated device_id. Either way, we |         # old access_token without an associated device_id. Either way, we | ||||||
|         # need to double-check the device is registered to avoid ending up with |         # need to double-check the device is registered to avoid ending up with | ||||||
|         # keys without a corresponding device. |         # keys without a corresponding device. | ||||||
|         self.device_handler.check_device_registered(user_id, device_id) |         yield self.device_handler.check_device_registered(user_id, device_id) | ||||||
| 
 | 
 | ||||||
|         result = yield self.store.count_e2e_one_time_keys(user_id, device_id) |         result = yield self.store.count_e2e_one_time_keys(user_id, device_id) | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -14,6 +14,7 @@ | |||||||
| # limitations under the License. | # limitations under the License. | ||||||
| 
 | 
 | ||||||
| """Contains handlers for federation events.""" | """Contains handlers for federation events.""" | ||||||
|  | import synapse.util.logcontext | ||||||
| from signedjson.key import decode_verify_key_bytes | from signedjson.key import decode_verify_key_bytes | ||||||
| from signedjson.sign import verify_signed_json | from signedjson.sign import verify_signed_json | ||||||
| from unpaddedbase64 import decode_base64 | from unpaddedbase64 import decode_base64 | ||||||
| @ -31,7 +32,7 @@ from synapse.util.logcontext import ( | |||||||
| ) | ) | ||||||
| from synapse.util.metrics import measure_func | from synapse.util.metrics import measure_func | ||||||
| from synapse.util.logutils import log_function | from synapse.util.logutils import log_function | ||||||
| from synapse.util.async import run_on_reactor | from synapse.util.async import run_on_reactor, Linearizer | ||||||
| from synapse.util.frozenutils import unfreeze | from synapse.util.frozenutils import unfreeze | ||||||
| from synapse.crypto.event_signing import ( | from synapse.crypto.event_signing import ( | ||||||
|     compute_event_signature, add_hashes_and_signatures, |     compute_event_signature, add_hashes_and_signatures, | ||||||
| @ -79,29 +80,216 @@ class FederationHandler(BaseHandler): | |||||||
| 
 | 
 | ||||||
|         # When joining a room we need to queue any events for that room up |         # When joining a room we need to queue any events for that room up | ||||||
|         self.room_queues = {} |         self.room_queues = {} | ||||||
|  |         self._room_pdu_linearizer = Linearizer("fed_room_pdu") | ||||||
| 
 | 
 | ||||||
|     @log_function |  | ||||||
|     @defer.inlineCallbacks |     @defer.inlineCallbacks | ||||||
|     def on_receive_pdu(self, origin, pdu, state=None, auth_chain=None): |     @log_function | ||||||
|         """ Called by the ReplicationLayer when we have a new pdu. We need to |     def on_receive_pdu(self, origin, pdu, get_missing=True): | ||||||
|         do auth checks and put it through the StateHandler. |         """ Process a PDU received via a federation /send/ transaction, or | ||||||
|  |         via backfill of missing prev_events | ||||||
| 
 | 
 | ||||||
|         auth_chain and state are None if we already have the necessary state |         Args: | ||||||
|         and prev_events in the db |             origin (str): server which initiated the /send/ transaction. Will | ||||||
|  |                 be used to fetch missing events or state. | ||||||
|  |             pdu (FrozenEvent): received PDU | ||||||
|  |             get_missing (bool): True if we should fetch missing prev_events | ||||||
|  | 
 | ||||||
|  |         Returns (Deferred): completes with None | ||||||
|         """ |         """ | ||||||
|         event = pdu |  | ||||||
| 
 | 
 | ||||||
|         logger.debug("Got event: %s", event.event_id) |         # We reprocess pdus when we have seen them only as outliers | ||||||
|  |         existing = yield self.get_persisted_pdu( | ||||||
|  |             origin, pdu.event_id, do_auth=False | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         # FIXME: Currently we fetch an event again when we already have it | ||||||
|  |         # if it has been marked as an outlier. | ||||||
|  | 
 | ||||||
|  |         already_seen = ( | ||||||
|  |             existing and ( | ||||||
|  |                 not existing.internal_metadata.is_outlier() | ||||||
|  |                 or pdu.internal_metadata.is_outlier() | ||||||
|  |             ) | ||||||
|  |         ) | ||||||
|  |         if already_seen: | ||||||
|  |             logger.debug("Already seen pdu %s", pdu.event_id) | ||||||
|  |             return | ||||||
| 
 | 
 | ||||||
|         # If we are currently in the process of joining this room, then we |         # If we are currently in the process of joining this room, then we | ||||||
|         # queue up events for later processing. |         # queue up events for later processing. | ||||||
|         if event.room_id in self.room_queues: |         if pdu.room_id in self.room_queues: | ||||||
|             self.room_queues[event.room_id].append((pdu, origin)) |             logger.info("Ignoring PDU %s for room %s from %s for now; join " | ||||||
|  |                         "in progress", pdu.event_id, pdu.room_id, origin) | ||||||
|  |             self.room_queues[pdu.room_id].append((pdu, origin)) | ||||||
|             return |             return | ||||||
| 
 | 
 | ||||||
|         logger.debug("Processing event: %s", event.event_id) |         state = None | ||||||
| 
 | 
 | ||||||
|         logger.debug("Event: %s", event) |         auth_chain = [] | ||||||
|  | 
 | ||||||
|  |         have_seen = yield self.store.have_events( | ||||||
|  |             [ev for ev, _ in pdu.prev_events] | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         fetch_state = False | ||||||
|  | 
 | ||||||
|  |         # Get missing pdus if necessary. | ||||||
|  |         if not pdu.internal_metadata.is_outlier(): | ||||||
|  |             # We only backfill backwards to the min depth. | ||||||
|  |             min_depth = yield self.get_min_depth_for_context( | ||||||
|  |                 pdu.room_id | ||||||
|  |             ) | ||||||
|  | 
 | ||||||
|  |             logger.debug( | ||||||
|  |                 "_handle_new_pdu min_depth for %s: %d", | ||||||
|  |                 pdu.room_id, min_depth | ||||||
|  |             ) | ||||||
|  | 
 | ||||||
|  |             prevs = {e_id for e_id, _ in pdu.prev_events} | ||||||
|  |             seen = set(have_seen.keys()) | ||||||
|  | 
 | ||||||
|  |             if min_depth and pdu.depth < min_depth: | ||||||
|  |                 # This is so that we don't notify the user about this | ||||||
|  |                 # message, to work around the fact that some events will | ||||||
|  |                 # reference really really old events we really don't want to | ||||||
|  |                 # send to the clients. | ||||||
|  |                 pdu.internal_metadata.outlier = True | ||||||
|  |             elif min_depth and pdu.depth > min_depth: | ||||||
|  |                 if get_missing and prevs - seen: | ||||||
|  |                     # If we're missing stuff, ensure we only fetch stuff one | ||||||
|  |                     # at a time. | ||||||
|  |                     logger.info( | ||||||
|  |                         "Acquiring lock for room %r to fetch %d missing events: %r...", | ||||||
|  |                         pdu.room_id, len(prevs - seen), list(prevs - seen)[:5], | ||||||
|  |                     ) | ||||||
|  |                     with (yield self._room_pdu_linearizer.queue(pdu.room_id)): | ||||||
|  |                         logger.info( | ||||||
|  |                             "Acquired lock for room %r to fetch %d missing events", | ||||||
|  |                             pdu.room_id, len(prevs - seen), | ||||||
|  |                         ) | ||||||
|  | 
 | ||||||
|  |                         yield self._get_missing_events_for_pdu( | ||||||
|  |                             origin, pdu, prevs, min_depth | ||||||
|  |                         ) | ||||||
|  | 
 | ||||||
|  |             prevs = {e_id for e_id, _ in pdu.prev_events} | ||||||
|  |             seen = set(have_seen.keys()) | ||||||
|  |             if prevs - seen: | ||||||
|  |                 logger.info( | ||||||
|  |                     "Still missing %d events for room %r: %r...", | ||||||
|  |                     len(prevs - seen), pdu.room_id, list(prevs - seen)[:5] | ||||||
|  |                 ) | ||||||
|  |                 fetch_state = True | ||||||
|  | 
 | ||||||
|  |         if fetch_state: | ||||||
|  |             # We need to get the state at this event, since we haven't | ||||||
|  |             # processed all the prev events. | ||||||
|  |             logger.debug( | ||||||
|  |                 "_handle_new_pdu getting state for %s", | ||||||
|  |                 pdu.room_id | ||||||
|  |             ) | ||||||
|  |             try: | ||||||
|  |                 state, auth_chain = yield self.replication_layer.get_state_for_room( | ||||||
|  |                     origin, pdu.room_id, pdu.event_id, | ||||||
|  |                 ) | ||||||
|  |             except: | ||||||
|  |                 logger.exception("Failed to get state for event: %s", pdu.event_id) | ||||||
|  | 
 | ||||||
|  |         yield self._process_received_pdu( | ||||||
|  |             origin, | ||||||
|  |             pdu, | ||||||
|  |             state=state, | ||||||
|  |             auth_chain=auth_chain, | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |     @defer.inlineCallbacks | ||||||
|  |     def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth): | ||||||
|  |         """ | ||||||
|  |         Args: | ||||||
|  |             origin (str): Origin of the pdu. Will be called to get the missing events | ||||||
|  |             pdu: received pdu | ||||||
|  |             prevs (str[]): List of event ids which we are missing | ||||||
|  |             min_depth (int): Minimum depth of events to return. | ||||||
|  | 
 | ||||||
|  |         Returns: | ||||||
|  |             Deferred<dict(str, str?)>: updated have_seen dictionary | ||||||
|  |         """ | ||||||
|  |         # We recalculate seen, since it may have changed. | ||||||
|  |         have_seen = yield self.store.have_events(prevs) | ||||||
|  |         seen = set(have_seen.keys()) | ||||||
|  | 
 | ||||||
|  |         if not prevs - seen: | ||||||
|  |             # nothing left to do | ||||||
|  |             defer.returnValue(have_seen) | ||||||
|  | 
 | ||||||
|  |         latest = yield self.store.get_latest_event_ids_in_room( | ||||||
|  |             pdu.room_id | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         # We add the prev events that we have seen to the latest | ||||||
|  |         # list to ensure the remote server doesn't give them to us | ||||||
|  |         latest = set(latest) | ||||||
|  |         latest |= seen | ||||||
|  | 
 | ||||||
|  |         logger.info( | ||||||
|  |             "Missing %d events for room %r: %r...", | ||||||
|  |             len(prevs - seen), pdu.room_id, list(prevs - seen)[:5] | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         # XXX: we set timeout to 10s to help workaround | ||||||
|  |         # https://github.com/matrix-org/synapse/issues/1733. | ||||||
|  |         # The reason is to avoid holding the linearizer lock | ||||||
|  |         # whilst processing inbound /send transactions, causing | ||||||
|  |         # FDs to stack up and block other inbound transactions | ||||||
|  |         # which empirically can currently take up to 30 minutes. | ||||||
|  |         # | ||||||
|  |         # N.B. this explicitly disables retry attempts. | ||||||
|  |         # | ||||||
|  |         # N.B. this also increases our chances of falling back to | ||||||
|  |         # fetching fresh state for the room if the missing event | ||||||
|  |         # can't be found, which slightly reduces our security. | ||||||
|  |         # it may also increase our DAG extremity count for the room, | ||||||
|  |         # causing additional state resolution?  See #1760. | ||||||
|  |         # However, fetching state doesn't hold the linearizer lock | ||||||
|  |         # apparently. | ||||||
|  |         # | ||||||
|  |         # see https://github.com/matrix-org/synapse/pull/1744 | ||||||
|  | 
 | ||||||
|  |         missing_events = yield self.replication_layer.get_missing_events( | ||||||
|  |             origin, | ||||||
|  |             pdu.room_id, | ||||||
|  |             earliest_events_ids=list(latest), | ||||||
|  |             latest_events=[pdu], | ||||||
|  |             limit=10, | ||||||
|  |             min_depth=min_depth, | ||||||
|  |             timeout=10000, | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         # We want to sort these by depth so we process them and | ||||||
|  |         # tell clients about them in order. | ||||||
|  |         missing_events.sort(key=lambda x: x.depth) | ||||||
|  | 
 | ||||||
|  |         for e in missing_events: | ||||||
|  |             yield self.on_receive_pdu( | ||||||
|  |                 origin, | ||||||
|  |                 e, | ||||||
|  |                 get_missing=False | ||||||
|  |             ) | ||||||
|  | 
 | ||||||
|  |         have_seen = yield self.store.have_events( | ||||||
|  |             [ev for ev, _ in pdu.prev_events] | ||||||
|  |         ) | ||||||
|  |         defer.returnValue(have_seen) | ||||||
|  | 
 | ||||||
|  |     @log_function | ||||||
|  |     @defer.inlineCallbacks | ||||||
|  |     def _process_received_pdu(self, origin, pdu, state, auth_chain): | ||||||
|  |         """ Called when we have a new pdu. We need to do auth checks and put it | ||||||
|  |         through the StateHandler. | ||||||
|  |         """ | ||||||
|  |         event = pdu | ||||||
|  | 
 | ||||||
|  |         logger.debug("Processing event: %s", event) | ||||||
| 
 | 
 | ||||||
|         # FIXME (erikj): Awful hack to make the case where we are not currently |         # FIXME (erikj): Awful hack to make the case where we are not currently | ||||||
|         # in the room work |         # in the room work | ||||||
| @ -670,8 +858,6 @@ class FederationHandler(BaseHandler): | |||||||
|         """ |         """ | ||||||
|         logger.debug("Joining %s to %s", joinee, room_id) |         logger.debug("Joining %s to %s", joinee, room_id) | ||||||
| 
 | 
 | ||||||
|         yield self.store.clean_room_for_join(room_id) |  | ||||||
| 
 |  | ||||||
|         origin, event = yield self._make_and_verify_event( |         origin, event = yield self._make_and_verify_event( | ||||||
|             target_hosts, |             target_hosts, | ||||||
|             room_id, |             room_id, | ||||||
| @ -680,7 +866,15 @@ class FederationHandler(BaseHandler): | |||||||
|             content, |             content, | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|  |         # This shouldn't happen, because the RoomMemberHandler has a | ||||||
|  |         # linearizer lock which only allows one operation per user per room | ||||||
|  |         # at a time - so this is just paranoia. | ||||||
|  |         assert (room_id not in self.room_queues) | ||||||
|  | 
 | ||||||
|         self.room_queues[room_id] = [] |         self.room_queues[room_id] = [] | ||||||
|  | 
 | ||||||
|  |         yield self.store.clean_room_for_join(room_id) | ||||||
|  | 
 | ||||||
|         handled_events = set() |         handled_events = set() | ||||||
| 
 | 
 | ||||||
|         try: |         try: | ||||||
| @ -733,17 +927,36 @@ class FederationHandler(BaseHandler): | |||||||
|             room_queue = self.room_queues[room_id] |             room_queue = self.room_queues[room_id] | ||||||
|             del self.room_queues[room_id] |             del self.room_queues[room_id] | ||||||
| 
 | 
 | ||||||
|             for p, origin in room_queue: |             # we don't need to wait for the queued events to be processed - | ||||||
|                 if p.event_id in handled_events: |             # it's just a best-effort thing at this point. We do want to do | ||||||
|                     continue |             # them roughly in order, though, otherwise we'll end up making | ||||||
|  |             # lots of requests for missing prev_events which we do actually | ||||||
|  |             # have. Hence we fire off the deferred, but don't wait for it. | ||||||
| 
 | 
 | ||||||
|                 try: |             synapse.util.logcontext.preserve_fn(self._handle_queued_pdus)( | ||||||
|                     self.on_receive_pdu(origin, p) |                 room_queue | ||||||
|                 except: |             ) | ||||||
|                     logger.exception("Couldn't handle pdu") |  | ||||||
| 
 | 
 | ||||||
|         defer.returnValue(True) |         defer.returnValue(True) | ||||||
| 
 | 
 | ||||||
|  |     @defer.inlineCallbacks | ||||||
|  |     def _handle_queued_pdus(self, room_queue): | ||||||
|  |         """Process PDUs which got queued up while we were busy send_joining. | ||||||
|  | 
 | ||||||
|  |         Args: | ||||||
|  |             room_queue (list[FrozenEvent, str]): list of PDUs to be processed | ||||||
|  |                 and the servers that sent them | ||||||
|  |         """ | ||||||
|  |         for p, origin in room_queue: | ||||||
|  |             try: | ||||||
|  |                 logger.info("Processing queued PDU %s which was received " | ||||||
|  |                             "while we were joining %s", p.event_id, p.room_id) | ||||||
|  |                 yield self.on_receive_pdu(origin, p) | ||||||
|  |             except Exception as e: | ||||||
|  |                 logger.warn( | ||||||
|  |                     "Error handling queued PDU %s from %s: %s", | ||||||
|  |                     p.event_id, origin, e) | ||||||
|  | 
 | ||||||
|     @defer.inlineCallbacks |     @defer.inlineCallbacks | ||||||
|     @log_function |     @log_function | ||||||
|     def on_make_join_request(self, room_id, user_id): |     def on_make_join_request(self, room_id, user_id): | ||||||
| @ -791,9 +1004,19 @@ class FederationHandler(BaseHandler): | |||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|         event.internal_metadata.outlier = False |         event.internal_metadata.outlier = False | ||||||
|         # Send this event on behalf of the origin server since they may not |         # Send this event on behalf of the origin server. | ||||||
|         # have an up to data view of the state of the room at this event so |         # | ||||||
|         # will not know which servers to send the event to. |         # The reasons we have the destination server rather than the origin | ||||||
|  |         # server send it are slightly mysterious: the origin server should have | ||||||
|  |         # all the neccessary state once it gets the response to the send_join, | ||||||
|  |         # so it could send the event itself if it wanted to. It may be that | ||||||
|  |         # doing it this way reduces failure modes, or avoids certain attacks | ||||||
|  |         # where a new server selectively tells a subset of the federation that | ||||||
|  |         # it has joined. | ||||||
|  |         # | ||||||
|  |         # The fact is that, as of the current writing, Synapse doesn't send out | ||||||
|  |         # the join event over federation after joining, and changing it now | ||||||
|  |         # would introduce the danger of backwards-compatibility problems. | ||||||
|         event.internal_metadata.send_on_behalf_of = origin |         event.internal_metadata.send_on_behalf_of = origin | ||||||
| 
 | 
 | ||||||
|         context, event_stream_id, max_stream_id = yield self._handle_new_event( |         context, event_stream_id, max_stream_id = yield self._handle_new_event( | ||||||
| @ -878,15 +1101,15 @@ class FederationHandler(BaseHandler): | |||||||
|                 user_id, |                 user_id, | ||||||
|                 "leave" |                 "leave" | ||||||
|             ) |             ) | ||||||
|             signed_event = self._sign_event(event) |             event = self._sign_event(event) | ||||||
|         except SynapseError: |         except SynapseError: | ||||||
|             raise |             raise | ||||||
|         except CodeMessageException as e: |         except CodeMessageException as e: | ||||||
|             logger.warn("Failed to reject invite: %s", e) |             logger.warn("Failed to reject invite: %s", e) | ||||||
|             raise SynapseError(500, "Failed to reject invite") |             raise SynapseError(500, "Failed to reject invite") | ||||||
| 
 | 
 | ||||||
|         # Try the host we successfully got a response to /make_join/ |         # Try the host that we succesfully called /make_leave/ on first for | ||||||
|         # request first. |         # the /send_leave/ request. | ||||||
|         try: |         try: | ||||||
|             target_hosts.remove(origin) |             target_hosts.remove(origin) | ||||||
|             target_hosts.insert(0, origin) |             target_hosts.insert(0, origin) | ||||||
| @ -896,7 +1119,7 @@ class FederationHandler(BaseHandler): | |||||||
|         try: |         try: | ||||||
|             yield self.replication_layer.send_leave( |             yield self.replication_layer.send_leave( | ||||||
|                 target_hosts, |                 target_hosts, | ||||||
|                 signed_event |                 event | ||||||
|             ) |             ) | ||||||
|         except SynapseError: |         except SynapseError: | ||||||
|             raise |             raise | ||||||
| @ -1325,7 +1548,17 @@ class FederationHandler(BaseHandler): | |||||||
| 
 | 
 | ||||||
|     @defer.inlineCallbacks |     @defer.inlineCallbacks | ||||||
|     def _prep_event(self, origin, event, state=None, auth_events=None): |     def _prep_event(self, origin, event, state=None, auth_events=None): | ||||||
|  |         """ | ||||||
| 
 | 
 | ||||||
|  |         Args: | ||||||
|  |             origin: | ||||||
|  |             event: | ||||||
|  |             state: | ||||||
|  |             auth_events: | ||||||
|  | 
 | ||||||
|  |         Returns: | ||||||
|  |             Deferred, which resolves to synapse.events.snapshot.EventContext | ||||||
|  |         """ | ||||||
|         context = yield self.state_handler.compute_event_context( |         context = yield self.state_handler.compute_event_context( | ||||||
|             event, old_state=state, |             event, old_state=state, | ||||||
|         ) |         ) | ||||||
|  | |||||||
| @ -1,5 +1,6 @@ | |||||||
| # -*- coding: utf-8 -*- | # -*- coding: utf-8 -*- | ||||||
| # Copyright 2015, 2016 OpenMarket Ltd | # Copyright 2015, 2016 OpenMarket Ltd | ||||||
|  | # Copyright 2017 Vector Creations Ltd | ||||||
| # | # | ||||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||||
| @ -150,7 +151,7 @@ class IdentityHandler(BaseHandler): | |||||||
|         params.update(kwargs) |         params.update(kwargs) | ||||||
| 
 | 
 | ||||||
|         try: |         try: | ||||||
|             data = yield self.http_client.post_urlencoded_get_json( |             data = yield self.http_client.post_json_get_json( | ||||||
|                 "https://%s%s" % ( |                 "https://%s%s" % ( | ||||||
|                     id_server, |                     id_server, | ||||||
|                     "/_matrix/identity/api/v1/validate/email/requestToken" |                     "/_matrix/identity/api/v1/validate/email/requestToken" | ||||||
| @ -161,3 +162,37 @@ class IdentityHandler(BaseHandler): | |||||||
|         except CodeMessageException as e: |         except CodeMessageException as e: | ||||||
|             logger.info("Proxied requestToken failed: %r", e) |             logger.info("Proxied requestToken failed: %r", e) | ||||||
|             raise e |             raise e | ||||||
|  | 
 | ||||||
|  |     @defer.inlineCallbacks | ||||||
|  |     def requestMsisdnToken( | ||||||
|  |             self, id_server, country, phone_number, | ||||||
|  |             client_secret, send_attempt, **kwargs | ||||||
|  |     ): | ||||||
|  |         yield run_on_reactor() | ||||||
|  | 
 | ||||||
|  |         if not self._should_trust_id_server(id_server): | ||||||
|  |             raise SynapseError( | ||||||
|  |                 400, "Untrusted ID server '%s'" % id_server, | ||||||
|  |                 Codes.SERVER_NOT_TRUSTED | ||||||
|  |             ) | ||||||
|  | 
 | ||||||
|  |         params = { | ||||||
|  |             'country': country, | ||||||
|  |             'phone_number': phone_number, | ||||||
|  |             'client_secret': client_secret, | ||||||
|  |             'send_attempt': send_attempt, | ||||||
|  |         } | ||||||
|  |         params.update(kwargs) | ||||||
|  | 
 | ||||||
|  |         try: | ||||||
|  |             data = yield self.http_client.post_json_get_json( | ||||||
|  |                 "https://%s%s" % ( | ||||||
|  |                     id_server, | ||||||
|  |                     "/_matrix/identity/api/v1/validate/msisdn/requestToken" | ||||||
|  |                 ), | ||||||
|  |                 params | ||||||
|  |             ) | ||||||
|  |             defer.returnValue(data) | ||||||
|  |         except CodeMessageException as e: | ||||||
|  |             logger.info("Proxied requestToken failed: %r", e) | ||||||
|  |             raise e | ||||||
|  | |||||||
| @ -19,6 +19,7 @@ from synapse.api.constants import EventTypes, Membership | |||||||
| from synapse.api.errors import AuthError, Codes | from synapse.api.errors import AuthError, Codes | ||||||
| from synapse.events.utils import serialize_event | from synapse.events.utils import serialize_event | ||||||
| from synapse.events.validator import EventValidator | from synapse.events.validator import EventValidator | ||||||
|  | from synapse.handlers.presence import format_user_presence_state | ||||||
| from synapse.streams.config import PaginationConfig | from synapse.streams.config import PaginationConfig | ||||||
| from synapse.types import ( | from synapse.types import ( | ||||||
|     UserID, StreamToken, |     UserID, StreamToken, | ||||||
| @ -225,9 +226,17 @@ class InitialSyncHandler(BaseHandler): | |||||||
|                 "content": content, |                 "content": content, | ||||||
|             }) |             }) | ||||||
| 
 | 
 | ||||||
|  |         now = self.clock.time_msec() | ||||||
|  | 
 | ||||||
|         ret = { |         ret = { | ||||||
|             "rooms": rooms_ret, |             "rooms": rooms_ret, | ||||||
|             "presence": presence, |             "presence": [ | ||||||
|  |                 { | ||||||
|  |                     "type": "m.presence", | ||||||
|  |                     "content": format_user_presence_state(event, now), | ||||||
|  |                 } | ||||||
|  |                 for event in presence | ||||||
|  |             ], | ||||||
|             "account_data": account_data_events, |             "account_data": account_data_events, | ||||||
|             "receipts": receipt, |             "receipts": receipt, | ||||||
|             "end": now_token.to_string(), |             "end": now_token.to_string(), | ||||||
|  | |||||||
| @ -29,6 +29,7 @@ from synapse.api.errors import SynapseError | |||||||
| from synapse.api.constants import PresenceState | from synapse.api.constants import PresenceState | ||||||
| from synapse.storage.presence import UserPresenceState | from synapse.storage.presence import UserPresenceState | ||||||
| 
 | 
 | ||||||
|  | from synapse.util.caches.descriptors import cachedInlineCallbacks | ||||||
| from synapse.util.logcontext import preserve_fn | from synapse.util.logcontext import preserve_fn | ||||||
| from synapse.util.logutils import log_function | from synapse.util.logutils import log_function | ||||||
| from synapse.util.metrics import Measure | from synapse.util.metrics import Measure | ||||||
| @ -556,9 +557,9 @@ class PresenceHandler(object): | |||||||
|         room_ids_to_states = {} |         room_ids_to_states = {} | ||||||
|         users_to_states = {} |         users_to_states = {} | ||||||
|         for state in states: |         for state in states: | ||||||
|             events = yield self.store.get_rooms_for_user(state.user_id) |             room_ids = yield self.store.get_rooms_for_user(state.user_id) | ||||||
|             for e in events: |             for room_id in room_ids: | ||||||
|                 room_ids_to_states.setdefault(e.room_id, []).append(state) |                 room_ids_to_states.setdefault(room_id, []).append(state) | ||||||
| 
 | 
 | ||||||
|             plist = yield self.store.get_presence_list_observers_accepted(state.user_id) |             plist = yield self.store.get_presence_list_observers_accepted(state.user_id) | ||||||
|             for u in plist: |             for u in plist: | ||||||
| @ -574,8 +575,7 @@ class PresenceHandler(object): | |||||||
|                 if not local_states: |                 if not local_states: | ||||||
|                     continue |                     continue | ||||||
| 
 | 
 | ||||||
|                 users = yield self.store.get_users_in_room(room_id) |                 hosts = yield self.store.get_hosts_in_room(room_id) | ||||||
|                 hosts = set(get_domain_from_id(u) for u in users) |  | ||||||
| 
 | 
 | ||||||
|                 for host in hosts: |                 for host in hosts: | ||||||
|                     hosts_to_states.setdefault(host, []).extend(local_states) |                     hosts_to_states.setdefault(host, []).extend(local_states) | ||||||
| @ -719,9 +719,7 @@ class PresenceHandler(object): | |||||||
|                 for state in updates |                 for state in updates | ||||||
|             ]) |             ]) | ||||||
|         else: |         else: | ||||||
|             defer.returnValue([ |             defer.returnValue(updates) | ||||||
|                 format_user_presence_state(state, now) for state in updates |  | ||||||
|             ]) |  | ||||||
| 
 | 
 | ||||||
|     @defer.inlineCallbacks |     @defer.inlineCallbacks | ||||||
|     def set_state(self, target_user, state, ignore_status_msg=False): |     def set_state(self, target_user, state, ignore_status_msg=False): | ||||||
| @ -795,6 +793,9 @@ class PresenceHandler(object): | |||||||
|             as_event=False, |             as_event=False, | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|  |         now = self.clock.time_msec() | ||||||
|  |         results[:] = [format_user_presence_state(r, now) for r in results] | ||||||
|  | 
 | ||||||
|         is_accepted = { |         is_accepted = { | ||||||
|             row["observed_user_id"]: row["accepted"] for row in presence_list |             row["observed_user_id"]: row["accepted"] for row in presence_list | ||||||
|         } |         } | ||||||
| @ -847,6 +848,7 @@ class PresenceHandler(object): | |||||||
|             ) |             ) | ||||||
| 
 | 
 | ||||||
|             state_dict = yield self.get_state(observed_user, as_event=False) |             state_dict = yield self.get_state(observed_user, as_event=False) | ||||||
|  |             state_dict = format_user_presence_state(state_dict, self.clock.time_msec()) | ||||||
| 
 | 
 | ||||||
|             self.federation.send_edu( |             self.federation.send_edu( | ||||||
|                 destination=observer_user.domain, |                 destination=observer_user.domain, | ||||||
| @ -910,11 +912,12 @@ class PresenceHandler(object): | |||||||
|     def is_visible(self, observed_user, observer_user): |     def is_visible(self, observed_user, observer_user): | ||||||
|         """Returns whether a user can see another user's presence. |         """Returns whether a user can see another user's presence. | ||||||
|         """ |         """ | ||||||
|         observer_rooms = yield self.store.get_rooms_for_user(observer_user.to_string()) |         observer_room_ids = yield self.store.get_rooms_for_user( | ||||||
|         observed_rooms = yield self.store.get_rooms_for_user(observed_user.to_string()) |             observer_user.to_string() | ||||||
| 
 |         ) | ||||||
|         observer_room_ids = set(r.room_id for r in observer_rooms) |         observed_room_ids = yield self.store.get_rooms_for_user( | ||||||
|         observed_room_ids = set(r.room_id for r in observed_rooms) |             observed_user.to_string() | ||||||
|  |         ) | ||||||
| 
 | 
 | ||||||
|         if observer_room_ids & observed_room_ids: |         if observer_room_ids & observed_room_ids: | ||||||
|             defer.returnValue(True) |             defer.returnValue(True) | ||||||
| @ -979,14 +982,18 @@ def should_notify(old_state, new_state): | |||||||
|     return False |     return False | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def format_user_presence_state(state, now): | def format_user_presence_state(state, now, include_user_id=True): | ||||||
|     """Convert UserPresenceState to a format that can be sent down to clients |     """Convert UserPresenceState to a format that can be sent down to clients | ||||||
|     and to other servers. |     and to other servers. | ||||||
|  | 
 | ||||||
|  |     The "user_id" is optional so that this function can be used to format presence | ||||||
|  |     updates for client /sync responses and for federation /send requests. | ||||||
|     """ |     """ | ||||||
|     content = { |     content = { | ||||||
|         "presence": state.state, |         "presence": state.state, | ||||||
|         "user_id": state.user_id, |  | ||||||
|     } |     } | ||||||
|  |     if include_user_id: | ||||||
|  |         content["user_id"] = state.user_id | ||||||
|     if state.last_active_ts: |     if state.last_active_ts: | ||||||
|         content["last_active_ago"] = now - state.last_active_ts |         content["last_active_ago"] = now - state.last_active_ts | ||||||
|     if state.status_msg and state.state != PresenceState.OFFLINE: |     if state.status_msg and state.state != PresenceState.OFFLINE: | ||||||
| @ -1025,7 +1032,6 @@ class PresenceEventSource(object): | |||||||
|         # sending down the rare duplicate is not a concern. |         # sending down the rare duplicate is not a concern. | ||||||
| 
 | 
 | ||||||
|         with Measure(self.clock, "presence.get_new_events"): |         with Measure(self.clock, "presence.get_new_events"): | ||||||
|             user_id = user.to_string() |  | ||||||
|             if from_key is not None: |             if from_key is not None: | ||||||
|                 from_key = int(from_key) |                 from_key = int(from_key) | ||||||
| 
 | 
 | ||||||
| @ -1034,18 +1040,7 @@ class PresenceEventSource(object): | |||||||
| 
 | 
 | ||||||
|             max_token = self.store.get_current_presence_token() |             max_token = self.store.get_current_presence_token() | ||||||
| 
 | 
 | ||||||
|             plist = yield self.store.get_presence_list_accepted(user.localpart) |             users_interested_in = yield self._get_interested_in(user, explicit_room_id) | ||||||
|             users_interested_in = set(row["observed_user_id"] for row in plist) |  | ||||||
|             users_interested_in.add(user_id)  # So that we receive our own presence |  | ||||||
| 
 |  | ||||||
|             users_who_share_room = yield self.store.get_users_who_share_room_with_user( |  | ||||||
|                 user_id |  | ||||||
|             ) |  | ||||||
|             users_interested_in.update(users_who_share_room) |  | ||||||
| 
 |  | ||||||
|             if explicit_room_id: |  | ||||||
|                 user_ids = yield self.store.get_users_in_room(explicit_room_id) |  | ||||||
|                 users_interested_in.update(user_ids) |  | ||||||
| 
 | 
 | ||||||
|             user_ids_changed = set() |             user_ids_changed = set() | ||||||
|             changed = None |             changed = None | ||||||
| @ -1073,15 +1068,12 @@ class PresenceEventSource(object): | |||||||
| 
 | 
 | ||||||
|             updates = yield presence.current_state_for_users(user_ids_changed) |             updates = yield presence.current_state_for_users(user_ids_changed) | ||||||
| 
 | 
 | ||||||
|         now = self.clock.time_msec() |         if include_offline: | ||||||
| 
 |             defer.returnValue((updates.values(), max_token)) | ||||||
|  |         else: | ||||||
|             defer.returnValue(([ |             defer.returnValue(([ | ||||||
|             { |                 s for s in updates.itervalues() | ||||||
|                 "type": "m.presence", |                 if s.state != PresenceState.OFFLINE | ||||||
|                 "content": format_user_presence_state(s, now), |  | ||||||
|             } |  | ||||||
|             for s in updates.values() |  | ||||||
|             if include_offline or s.state != PresenceState.OFFLINE |  | ||||||
|             ], max_token)) |             ], max_token)) | ||||||
| 
 | 
 | ||||||
|     def get_current_key(self): |     def get_current_key(self): | ||||||
| @ -1090,6 +1082,31 @@ class PresenceEventSource(object): | |||||||
|     def get_pagination_rows(self, user, pagination_config, key): |     def get_pagination_rows(self, user, pagination_config, key): | ||||||
|         return self.get_new_events(user, from_key=None, include_offline=False) |         return self.get_new_events(user, from_key=None, include_offline=False) | ||||||
| 
 | 
 | ||||||
|  |     @cachedInlineCallbacks(num_args=2, cache_context=True) | ||||||
|  |     def _get_interested_in(self, user, explicit_room_id, cache_context): | ||||||
|  |         """Returns the set of users that the given user should see presence | ||||||
|  |         updates for | ||||||
|  |         """ | ||||||
|  |         user_id = user.to_string() | ||||||
|  |         plist = yield self.store.get_presence_list_accepted( | ||||||
|  |             user.localpart, on_invalidate=cache_context.invalidate, | ||||||
|  |         ) | ||||||
|  |         users_interested_in = set(row["observed_user_id"] for row in plist) | ||||||
|  |         users_interested_in.add(user_id)  # So that we receive our own presence | ||||||
|  | 
 | ||||||
|  |         users_who_share_room = yield self.store.get_users_who_share_room_with_user( | ||||||
|  |             user_id, on_invalidate=cache_context.invalidate, | ||||||
|  |         ) | ||||||
|  |         users_interested_in.update(users_who_share_room) | ||||||
|  | 
 | ||||||
|  |         if explicit_room_id: | ||||||
|  |             user_ids = yield self.store.get_users_in_room( | ||||||
|  |                 explicit_room_id, on_invalidate=cache_context.invalidate, | ||||||
|  |             ) | ||||||
|  |             users_interested_in.update(user_ids) | ||||||
|  | 
 | ||||||
|  |         defer.returnValue(users_interested_in) | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
| def handle_timeouts(user_states, is_mine_fn, syncing_user_ids, now): | def handle_timeouts(user_states, is_mine_fn, syncing_user_ids, now): | ||||||
|     """Checks the presence of users that have timed out and updates as |     """Checks the presence of users that have timed out and updates as | ||||||
| @ -1157,7 +1174,10 @@ def handle_timeout(state, is_mine, syncing_user_ids, now): | |||||||
|         # If there are have been no sync for a while (and none ongoing), |         # If there are have been no sync for a while (and none ongoing), | ||||||
|         # set presence to offline |         # set presence to offline | ||||||
|         if user_id not in syncing_user_ids: |         if user_id not in syncing_user_ids: | ||||||
|             if now - state.last_user_sync_ts > SYNC_ONLINE_TIMEOUT: |             # If the user has done something recently but hasn't synced, | ||||||
|  |             # don't set them as offline. | ||||||
|  |             sync_or_active = max(state.last_user_sync_ts, state.last_active_ts) | ||||||
|  |             if now - sync_or_active > SYNC_ONLINE_TIMEOUT: | ||||||
|                 state = state.copy_and_replace( |                 state = state.copy_and_replace( | ||||||
|                     state=PresenceState.OFFLINE, |                     state=PresenceState.OFFLINE, | ||||||
|                     status_msg=None, |                     status_msg=None, | ||||||
|  | |||||||
| @ -52,7 +52,8 @@ class ProfileHandler(BaseHandler): | |||||||
|                     args={ |                     args={ | ||||||
|                         "user_id": target_user.to_string(), |                         "user_id": target_user.to_string(), | ||||||
|                         "field": "displayname", |                         "field": "displayname", | ||||||
|                     } |                     }, | ||||||
|  |                     ignore_backoff=True, | ||||||
|                 ) |                 ) | ||||||
|             except CodeMessageException as e: |             except CodeMessageException as e: | ||||||
|                 if e.code != 404: |                 if e.code != 404: | ||||||
| @ -99,7 +100,8 @@ class ProfileHandler(BaseHandler): | |||||||
|                     args={ |                     args={ | ||||||
|                         "user_id": target_user.to_string(), |                         "user_id": target_user.to_string(), | ||||||
|                         "field": "avatar_url", |                         "field": "avatar_url", | ||||||
|                     } |                     }, | ||||||
|  |                     ignore_backoff=True, | ||||||
|                 ) |                 ) | ||||||
|             except CodeMessageException as e: |             except CodeMessageException as e: | ||||||
|                 if e.code != 404: |                 if e.code != 404: | ||||||
| @ -156,11 +158,11 @@ class ProfileHandler(BaseHandler): | |||||||
| 
 | 
 | ||||||
|         self.ratelimit(requester) |         self.ratelimit(requester) | ||||||
| 
 | 
 | ||||||
|         joins = yield self.store.get_rooms_for_user( |         room_ids = yield self.store.get_rooms_for_user( | ||||||
|             user.to_string(), |             user.to_string(), | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|         for j in joins: |         for room_id in room_ids: | ||||||
|             handler = self.hs.get_handlers().room_member_handler |             handler = self.hs.get_handlers().room_member_handler | ||||||
|             try: |             try: | ||||||
|                 # Assume the user isn't a guest because we don't let guests set |                 # Assume the user isn't a guest because we don't let guests set | ||||||
| @ -171,12 +173,12 @@ class ProfileHandler(BaseHandler): | |||||||
|                 yield handler.update_membership( |                 yield handler.update_membership( | ||||||
|                     requester, |                     requester, | ||||||
|                     user, |                     user, | ||||||
|                     j.room_id, |                     room_id, | ||||||
|                     "join",  # We treat a profile update like a join. |                     "join",  # We treat a profile update like a join. | ||||||
|                     ratelimit=False,  # Try to hide that these events aren't atomic. |                     ratelimit=False,  # Try to hide that these events aren't atomic. | ||||||
|                 ) |                 ) | ||||||
|             except Exception as e: |             except Exception as e: | ||||||
|                 logger.warn( |                 logger.warn( | ||||||
|                     "Failed to update join event for room %s - %s", |                     "Failed to update join event for room %s - %s", | ||||||
|                     j.room_id, str(e.message) |                     room_id, str(e.message) | ||||||
|                 ) |                 ) | ||||||
|  | |||||||
| @ -210,10 +210,9 @@ class ReceiptEventSource(object): | |||||||
|         else: |         else: | ||||||
|             from_key = None |             from_key = None | ||||||
| 
 | 
 | ||||||
|         rooms = yield self.store.get_rooms_for_user(user.to_string()) |         room_ids = yield self.store.get_rooms_for_user(user.to_string()) | ||||||
|         rooms = [room.room_id for room in rooms] |  | ||||||
|         events = yield self.store.get_linearized_receipts_for_rooms( |         events = yield self.store.get_linearized_receipts_for_rooms( | ||||||
|             rooms, |             room_ids, | ||||||
|             from_key=from_key, |             from_key=from_key, | ||||||
|             to_key=to_key, |             to_key=to_key, | ||||||
|         ) |         ) | ||||||
|  | |||||||
| @ -21,6 +21,7 @@ from synapse.api.constants import ( | |||||||
|     EventTypes, JoinRules, |     EventTypes, JoinRules, | ||||||
| ) | ) | ||||||
| from synapse.util.async import concurrently_execute | from synapse.util.async import concurrently_execute | ||||||
|  | from synapse.util.caches.descriptors import cachedInlineCallbacks | ||||||
| from synapse.util.caches.response_cache import ResponseCache | from synapse.util.caches.response_cache import ResponseCache | ||||||
| from synapse.types import ThirdPartyInstanceID | from synapse.types import ThirdPartyInstanceID | ||||||
| 
 | 
 | ||||||
| @ -62,6 +63,10 @@ class RoomListHandler(BaseHandler): | |||||||
|                 appservice and network id to use an appservice specific one. |                 appservice and network id to use an appservice specific one. | ||||||
|                 Setting to None returns all public rooms across all lists. |                 Setting to None returns all public rooms across all lists. | ||||||
|         """ |         """ | ||||||
|  |         logger.info( | ||||||
|  |             "Getting public room list: limit=%r, since=%r, search=%r, network=%r", | ||||||
|  |             limit, since_token, bool(search_filter), network_tuple, | ||||||
|  |         ) | ||||||
|         if search_filter: |         if search_filter: | ||||||
|             # We explicitly don't bother caching searches or requests for |             # We explicitly don't bother caching searches or requests for | ||||||
|             # appservice specific lists. |             # appservice specific lists. | ||||||
| @ -91,7 +96,6 @@ class RoomListHandler(BaseHandler): | |||||||
| 
 | 
 | ||||||
|         rooms_to_order_value = {} |         rooms_to_order_value = {} | ||||||
|         rooms_to_num_joined = {} |         rooms_to_num_joined = {} | ||||||
|         rooms_to_latest_event_ids = {} |  | ||||||
| 
 | 
 | ||||||
|         newly_visible = [] |         newly_visible = [] | ||||||
|         newly_unpublished = [] |         newly_unpublished = [] | ||||||
| @ -116,12 +120,18 @@ class RoomListHandler(BaseHandler): | |||||||
| 
 | 
 | ||||||
|         @defer.inlineCallbacks |         @defer.inlineCallbacks | ||||||
|         def get_order_for_room(room_id): |         def get_order_for_room(room_id): | ||||||
|             latest_event_ids = rooms_to_latest_event_ids.get(room_id, None) |             # Most of the rooms won't have changed between the since token and | ||||||
|             if not latest_event_ids: |             # now (especially if the since token is "now"). So, we can ask what | ||||||
|  |             # the current users are in a room (that will hit a cache) and then | ||||||
|  |             # check if the room has changed since the since token. (We have to | ||||||
|  |             # do it in that order to avoid races). | ||||||
|  |             # If things have changed then fall back to getting the current state | ||||||
|  |             # at the since token. | ||||||
|  |             joined_users = yield self.store.get_users_in_room(room_id) | ||||||
|  |             if self.store.has_room_changed_since(room_id, stream_token): | ||||||
|                 latest_event_ids = yield self.store.get_forward_extremeties_for_room( |                 latest_event_ids = yield self.store.get_forward_extremeties_for_room( | ||||||
|                     room_id, stream_token |                     room_id, stream_token | ||||||
|                 ) |                 ) | ||||||
|                 rooms_to_latest_event_ids[room_id] = latest_event_ids |  | ||||||
| 
 | 
 | ||||||
|                 if not latest_event_ids: |                 if not latest_event_ids: | ||||||
|                     return |                     return | ||||||
| @ -129,6 +139,7 @@ class RoomListHandler(BaseHandler): | |||||||
|                 joined_users = yield self.state_handler.get_current_user_in_room( |                 joined_users = yield self.state_handler.get_current_user_in_room( | ||||||
|                     room_id, latest_event_ids, |                     room_id, latest_event_ids, | ||||||
|                 ) |                 ) | ||||||
|  | 
 | ||||||
|             num_joined_users = len(joined_users) |             num_joined_users = len(joined_users) | ||||||
|             rooms_to_num_joined[room_id] = num_joined_users |             rooms_to_num_joined[room_id] = num_joined_users | ||||||
| 
 | 
 | ||||||
| @ -165,19 +176,19 @@ class RoomListHandler(BaseHandler): | |||||||
|                 rooms_to_scan = rooms_to_scan[:since_token.current_limit] |                 rooms_to_scan = rooms_to_scan[:since_token.current_limit] | ||||||
|                 rooms_to_scan.reverse() |                 rooms_to_scan.reverse() | ||||||
| 
 | 
 | ||||||
|         # Actually generate the entries. _generate_room_entry will append to |         # Actually generate the entries. _append_room_entry_to_chunk will append to | ||||||
|         # chunk but will stop if len(chunk) > limit |         # chunk but will stop if len(chunk) > limit | ||||||
|         chunk = [] |         chunk = [] | ||||||
|         if limit and not search_filter: |         if limit and not search_filter: | ||||||
|             step = limit + 1 |             step = limit + 1 | ||||||
|             for i in xrange(0, len(rooms_to_scan), step): |             for i in xrange(0, len(rooms_to_scan), step): | ||||||
|                 # We iterate here because the vast majority of cases we'll stop |                 # We iterate here because the vast majority of cases we'll stop | ||||||
|                 # at first iteration, but occaisonally _generate_room_entry |                 # at first iteration, but occaisonally _append_room_entry_to_chunk | ||||||
|                 # won't append to the chunk and so we need to loop again. |                 # won't append to the chunk and so we need to loop again. | ||||||
|                 # We don't want to scan over the entire range either as that |                 # We don't want to scan over the entire range either as that | ||||||
|                 # would potentially waste a lot of work. |                 # would potentially waste a lot of work. | ||||||
|                 yield concurrently_execute( |                 yield concurrently_execute( | ||||||
|                     lambda r: self._generate_room_entry( |                     lambda r: self._append_room_entry_to_chunk( | ||||||
|                         r, rooms_to_num_joined[r], |                         r, rooms_to_num_joined[r], | ||||||
|                         chunk, limit, search_filter |                         chunk, limit, search_filter | ||||||
|                     ), |                     ), | ||||||
| @ -187,7 +198,7 @@ class RoomListHandler(BaseHandler): | |||||||
|                     break |                     break | ||||||
|         else: |         else: | ||||||
|             yield concurrently_execute( |             yield concurrently_execute( | ||||||
|                 lambda r: self._generate_room_entry( |                 lambda r: self._append_room_entry_to_chunk( | ||||||
|                     r, rooms_to_num_joined[r], |                     r, rooms_to_num_joined[r], | ||||||
|                     chunk, limit, search_filter |                     chunk, limit, search_filter | ||||||
|                 ), |                 ), | ||||||
| @ -256,21 +267,35 @@ class RoomListHandler(BaseHandler): | |||||||
|         defer.returnValue(results) |         defer.returnValue(results) | ||||||
| 
 | 
 | ||||||
|     @defer.inlineCallbacks |     @defer.inlineCallbacks | ||||||
|     def _generate_room_entry(self, room_id, num_joined_users, chunk, limit, |     def _append_room_entry_to_chunk(self, room_id, num_joined_users, chunk, limit, | ||||||
|                                     search_filter): |                                     search_filter): | ||||||
|  |         """Generate the entry for a room in the public room list and append it | ||||||
|  |         to the `chunk` if it matches the search filter | ||||||
|  |         """ | ||||||
|         if limit and len(chunk) > limit + 1: |         if limit and len(chunk) > limit + 1: | ||||||
|             # We've already got enough, so lets just drop it. |             # We've already got enough, so lets just drop it. | ||||||
|             return |             return | ||||||
| 
 | 
 | ||||||
|  |         result = yield self._generate_room_entry(room_id, num_joined_users) | ||||||
|  | 
 | ||||||
|  |         if result and _matches_room_entry(result, search_filter): | ||||||
|  |             chunk.append(result) | ||||||
|  | 
 | ||||||
|  |     @cachedInlineCallbacks(num_args=1, cache_context=True) | ||||||
|  |     def _generate_room_entry(self, room_id, num_joined_users, cache_context): | ||||||
|  |         """Returns the entry for a room | ||||||
|  |         """ | ||||||
|         result = { |         result = { | ||||||
|             "room_id": room_id, |             "room_id": room_id, | ||||||
|             "num_joined_members": num_joined_users, |             "num_joined_members": num_joined_users, | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         current_state_ids = yield self.state_handler.get_current_state_ids(room_id) |         current_state_ids = yield self.store.get_current_state_ids( | ||||||
|  |             room_id, on_invalidate=cache_context.invalidate, | ||||||
|  |         ) | ||||||
| 
 | 
 | ||||||
|         event_map = yield self.store.get_events([ |         event_map = yield self.store.get_events([ | ||||||
|             event_id for key, event_id in current_state_ids.items() |             event_id for key, event_id in current_state_ids.iteritems() | ||||||
|             if key[0] in ( |             if key[0] in ( | ||||||
|                 EventTypes.JoinRules, |                 EventTypes.JoinRules, | ||||||
|                 EventTypes.Name, |                 EventTypes.Name, | ||||||
| @ -294,7 +319,9 @@ class RoomListHandler(BaseHandler): | |||||||
|             if join_rule and join_rule != JoinRules.PUBLIC: |             if join_rule and join_rule != JoinRules.PUBLIC: | ||||||
|                 defer.returnValue(None) |                 defer.returnValue(None) | ||||||
| 
 | 
 | ||||||
|         aliases = yield self.store.get_aliases_for_room(room_id) |         aliases = yield self.store.get_aliases_for_room( | ||||||
|  |             room_id, on_invalidate=cache_context.invalidate | ||||||
|  |         ) | ||||||
|         if aliases: |         if aliases: | ||||||
|             result["aliases"] = aliases |             result["aliases"] = aliases | ||||||
| 
 | 
 | ||||||
| @ -334,8 +361,7 @@ class RoomListHandler(BaseHandler): | |||||||
|             if avatar_url: |             if avatar_url: | ||||||
|                 result["avatar_url"] = avatar_url |                 result["avatar_url"] = avatar_url | ||||||
| 
 | 
 | ||||||
|         if _matches_room_entry(result, search_filter): |         defer.returnValue(result) | ||||||
|             chunk.append(result) |  | ||||||
| 
 | 
 | ||||||
|     @defer.inlineCallbacks |     @defer.inlineCallbacks | ||||||
|     def get_remote_public_room_list(self, server_name, limit=None, since_token=None, |     def get_remote_public_room_list(self, server_name, limit=None, since_token=None, | ||||||
|  | |||||||
| @ -20,6 +20,7 @@ from synapse.util.metrics import Measure, measure_func | |||||||
| from synapse.util.caches.response_cache import ResponseCache | from synapse.util.caches.response_cache import ResponseCache | ||||||
| from synapse.push.clientformat import format_push_rules_for_user | from synapse.push.clientformat import format_push_rules_for_user | ||||||
| from synapse.visibility import filter_events_for_client | from synapse.visibility import filter_events_for_client | ||||||
|  | from synapse.types import RoomStreamToken | ||||||
| 
 | 
 | ||||||
| from twisted.internet import defer | from twisted.internet import defer | ||||||
| 
 | 
 | ||||||
| @ -225,8 +226,7 @@ class SyncHandler(object): | |||||||
|         with Measure(self.clock, "ephemeral_by_room"): |         with Measure(self.clock, "ephemeral_by_room"): | ||||||
|             typing_key = since_token.typing_key if since_token else "0" |             typing_key = since_token.typing_key if since_token else "0" | ||||||
| 
 | 
 | ||||||
|             rooms = yield self.store.get_rooms_for_user(sync_config.user.to_string()) |             room_ids = yield self.store.get_rooms_for_user(sync_config.user.to_string()) | ||||||
|             room_ids = [room.room_id for room in rooms] |  | ||||||
| 
 | 
 | ||||||
|             typing_source = self.event_sources.sources["typing"] |             typing_source = self.event_sources.sources["typing"] | ||||||
|             typing, typing_key = yield typing_source.get_new_events( |             typing, typing_key = yield typing_source.get_new_events( | ||||||
| @ -568,16 +568,15 @@ class SyncHandler(object): | |||||||
|         since_token = sync_result_builder.since_token |         since_token = sync_result_builder.since_token | ||||||
| 
 | 
 | ||||||
|         if since_token and since_token.device_list_key: |         if since_token and since_token.device_list_key: | ||||||
|             rooms = yield self.store.get_rooms_for_user(user_id) |             room_ids = yield self.store.get_rooms_for_user(user_id) | ||||||
|             room_ids = set(r.room_id for r in rooms) |  | ||||||
| 
 | 
 | ||||||
|             user_ids_changed = set() |             user_ids_changed = set() | ||||||
|             changed = yield self.store.get_user_whose_devices_changed( |             changed = yield self.store.get_user_whose_devices_changed( | ||||||
|                 since_token.device_list_key |                 since_token.device_list_key | ||||||
|             ) |             ) | ||||||
|             for other_user_id in changed: |             for other_user_id in changed: | ||||||
|                 other_rooms = yield self.store.get_rooms_for_user(other_user_id) |                 other_room_ids = yield self.store.get_rooms_for_user(other_user_id) | ||||||
|                 if room_ids.intersection(e.room_id for e in other_rooms): |                 if room_ids.intersection(other_room_ids): | ||||||
|                     user_ids_changed.add(other_user_id) |                     user_ids_changed.add(other_user_id) | ||||||
| 
 | 
 | ||||||
|             defer.returnValue(user_ids_changed) |             defer.returnValue(user_ids_changed) | ||||||
| @ -721,14 +720,14 @@ class SyncHandler(object): | |||||||
|             extra_users_ids.update(users) |             extra_users_ids.update(users) | ||||||
|         extra_users_ids.discard(user.to_string()) |         extra_users_ids.discard(user.to_string()) | ||||||
| 
 | 
 | ||||||
|  |         if extra_users_ids: | ||||||
|             states = yield self.presence_handler.get_states( |             states = yield self.presence_handler.get_states( | ||||||
|                 extra_users_ids, |                 extra_users_ids, | ||||||
|             as_event=True, |  | ||||||
|             ) |             ) | ||||||
|             presence.extend(states) |             presence.extend(states) | ||||||
| 
 | 
 | ||||||
|             # Deduplicate the presence entries so that there's at most one per user |             # Deduplicate the presence entries so that there's at most one per user | ||||||
|         presence = {p["content"]["user_id"]: p for p in presence}.values() |             presence = {p.user_id: p for p in presence}.values() | ||||||
| 
 | 
 | ||||||
|         presence = sync_config.filter_collection.filter_presence( |         presence = sync_config.filter_collection.filter_presence( | ||||||
|             presence |             presence | ||||||
| @ -765,6 +764,21 @@ class SyncHandler(object): | |||||||
|             ) |             ) | ||||||
|             sync_result_builder.now_token = now_token |             sync_result_builder.now_token = now_token | ||||||
| 
 | 
 | ||||||
|  |         # We check up front if anything has changed, if it hasn't then there is | ||||||
|  |         # no point in going futher. | ||||||
|  |         since_token = sync_result_builder.since_token | ||||||
|  |         if not sync_result_builder.full_state: | ||||||
|  |             if since_token and not ephemeral_by_room and not account_data_by_room: | ||||||
|  |                 have_changed = yield self._have_rooms_changed(sync_result_builder) | ||||||
|  |                 if not have_changed: | ||||||
|  |                     tags_by_room = yield self.store.get_updated_tags( | ||||||
|  |                         user_id, | ||||||
|  |                         since_token.account_data_key, | ||||||
|  |                     ) | ||||||
|  |                     if not tags_by_room: | ||||||
|  |                         logger.debug("no-oping sync") | ||||||
|  |                         defer.returnValue(([], [])) | ||||||
|  | 
 | ||||||
|         ignored_account_data = yield self.store.get_global_account_data_by_type_for_user( |         ignored_account_data = yield self.store.get_global_account_data_by_type_for_user( | ||||||
|             "m.ignored_user_list", user_id=user_id, |             "m.ignored_user_list", user_id=user_id, | ||||||
|         ) |         ) | ||||||
| @ -774,13 +788,12 @@ class SyncHandler(object): | |||||||
|         else: |         else: | ||||||
|             ignored_users = frozenset() |             ignored_users = frozenset() | ||||||
| 
 | 
 | ||||||
|         if sync_result_builder.since_token: |         if since_token: | ||||||
|             res = yield self._get_rooms_changed(sync_result_builder, ignored_users) |             res = yield self._get_rooms_changed(sync_result_builder, ignored_users) | ||||||
|             room_entries, invited, newly_joined_rooms = res |             room_entries, invited, newly_joined_rooms = res | ||||||
| 
 | 
 | ||||||
|             tags_by_room = yield self.store.get_updated_tags( |             tags_by_room = yield self.store.get_updated_tags( | ||||||
|                 user_id, |                 user_id, since_token.account_data_key, | ||||||
|                 sync_result_builder.since_token.account_data_key, |  | ||||||
|             ) |             ) | ||||||
|         else: |         else: | ||||||
|             res = yield self._get_all_rooms(sync_result_builder, ignored_users) |             res = yield self._get_all_rooms(sync_result_builder, ignored_users) | ||||||
| @ -805,7 +818,7 @@ class SyncHandler(object): | |||||||
| 
 | 
 | ||||||
|         # Now we want to get any newly joined users |         # Now we want to get any newly joined users | ||||||
|         newly_joined_users = set() |         newly_joined_users = set() | ||||||
|         if sync_result_builder.since_token: |         if since_token: | ||||||
|             for joined_sync in sync_result_builder.joined: |             for joined_sync in sync_result_builder.joined: | ||||||
|                 it = itertools.chain( |                 it = itertools.chain( | ||||||
|                     joined_sync.timeline.events, joined_sync.state.values() |                     joined_sync.timeline.events, joined_sync.state.values() | ||||||
| @ -817,6 +830,38 @@ class SyncHandler(object): | |||||||
| 
 | 
 | ||||||
|         defer.returnValue((newly_joined_rooms, newly_joined_users)) |         defer.returnValue((newly_joined_rooms, newly_joined_users)) | ||||||
| 
 | 
 | ||||||
|  |     @defer.inlineCallbacks | ||||||
|  |     def _have_rooms_changed(self, sync_result_builder): | ||||||
|  |         """Returns whether there may be any new events that should be sent down | ||||||
|  |         the sync. Returns True if there are. | ||||||
|  |         """ | ||||||
|  |         user_id = sync_result_builder.sync_config.user.to_string() | ||||||
|  |         since_token = sync_result_builder.since_token | ||||||
|  |         now_token = sync_result_builder.now_token | ||||||
|  | 
 | ||||||
|  |         assert since_token | ||||||
|  | 
 | ||||||
|  |         # Get a list of membership change events that have happened. | ||||||
|  |         rooms_changed = yield self.store.get_membership_changes_for_user( | ||||||
|  |             user_id, since_token.room_key, now_token.room_key | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         if rooms_changed: | ||||||
|  |             defer.returnValue(True) | ||||||
|  | 
 | ||||||
|  |         app_service = self.store.get_app_service_by_user_id(user_id) | ||||||
|  |         if app_service: | ||||||
|  |             rooms = yield self.store.get_app_service_rooms(app_service) | ||||||
|  |             joined_room_ids = set(r.room_id for r in rooms) | ||||||
|  |         else: | ||||||
|  |             joined_room_ids = yield self.store.get_rooms_for_user(user_id) | ||||||
|  | 
 | ||||||
|  |         stream_id = RoomStreamToken.parse_stream_token(since_token.room_key).stream | ||||||
|  |         for room_id in joined_room_ids: | ||||||
|  |             if self.store.has_room_changed_since(room_id, stream_id): | ||||||
|  |                 defer.returnValue(True) | ||||||
|  |         defer.returnValue(False) | ||||||
|  | 
 | ||||||
|     @defer.inlineCallbacks |     @defer.inlineCallbacks | ||||||
|     def _get_rooms_changed(self, sync_result_builder, ignored_users): |     def _get_rooms_changed(self, sync_result_builder, ignored_users): | ||||||
|         """Gets the the changes that have happened since the last sync. |         """Gets the the changes that have happened since the last sync. | ||||||
| @ -841,8 +886,7 @@ class SyncHandler(object): | |||||||
|             rooms = yield self.store.get_app_service_rooms(app_service) |             rooms = yield self.store.get_app_service_rooms(app_service) | ||||||
|             joined_room_ids = set(r.room_id for r in rooms) |             joined_room_ids = set(r.room_id for r in rooms) | ||||||
|         else: |         else: | ||||||
|             rooms = yield self.store.get_rooms_for_user(user_id) |             joined_room_ids = yield self.store.get_rooms_for_user(user_id) | ||||||
|             joined_room_ids = set(r.room_id for r in rooms) |  | ||||||
| 
 | 
 | ||||||
|         # Get a list of membership change events that have happened. |         # Get a list of membership change events that have happened. | ||||||
|         rooms_changed = yield self.store.get_membership_changes_for_user( |         rooms_changed = yield self.store.get_membership_changes_for_user( | ||||||
|  | |||||||
| @ -12,8 +12,7 @@ | |||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
| # See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||||
| # limitations under the License. | # limitations under the License. | ||||||
| 
 | import synapse.util.retryutils | ||||||
| 
 |  | ||||||
| from twisted.internet import defer, reactor, protocol | from twisted.internet import defer, reactor, protocol | ||||||
| from twisted.internet.error import DNSLookupError | from twisted.internet.error import DNSLookupError | ||||||
| from twisted.web.client import readBody, HTTPConnectionPool, Agent | from twisted.web.client import readBody, HTTPConnectionPool, Agent | ||||||
| @ -22,7 +21,7 @@ from twisted.web._newclient import ResponseDone | |||||||
| 
 | 
 | ||||||
| from synapse.http.endpoint import matrix_federation_endpoint | from synapse.http.endpoint import matrix_federation_endpoint | ||||||
| from synapse.util.async import sleep | from synapse.util.async import sleep | ||||||
| from synapse.util.logcontext import preserve_context_over_fn | from synapse.util import logcontext | ||||||
| import synapse.metrics | import synapse.metrics | ||||||
| 
 | 
 | ||||||
| from canonicaljson import encode_canonical_json | from canonicaljson import encode_canonical_json | ||||||
| @ -94,6 +93,7 @@ class MatrixFederationHttpClient(object): | |||||||
|             reactor, MatrixFederationEndpointFactory(hs), pool=pool |             reactor, MatrixFederationEndpointFactory(hs), pool=pool | ||||||
|         ) |         ) | ||||||
|         self.clock = hs.get_clock() |         self.clock = hs.get_clock() | ||||||
|  |         self._store = hs.get_datastore() | ||||||
|         self.version_string = hs.version_string |         self.version_string = hs.version_string | ||||||
|         self._next_id = 1 |         self._next_id = 1 | ||||||
| 
 | 
 | ||||||
| @ -103,12 +103,40 @@ class MatrixFederationHttpClient(object): | |||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|     @defer.inlineCallbacks |     @defer.inlineCallbacks | ||||||
|     def _create_request(self, destination, method, path_bytes, |     def _request(self, destination, method, path, | ||||||
|                  body_callback, headers_dict={}, param_bytes=b"", |                  body_callback, headers_dict={}, param_bytes=b"", | ||||||
|                  query_bytes=b"", retry_on_dns_fail=True, |                  query_bytes=b"", retry_on_dns_fail=True, | ||||||
|                         timeout=None, long_retries=False): |                  timeout=None, long_retries=False, | ||||||
|         """ Creates and sends a request to the given url |                  ignore_backoff=False, | ||||||
|  |                  backoff_on_404=False): | ||||||
|  |         """ Creates and sends a request to the given server | ||||||
|  |         Args: | ||||||
|  |             destination (str): The remote server to send the HTTP request to. | ||||||
|  |             method (str): HTTP method | ||||||
|  |             path (str): The HTTP path | ||||||
|  |             ignore_backoff (bool): true to ignore the historical backoff data | ||||||
|  |                 and try the request anyway. | ||||||
|  |             backoff_on_404 (bool): Back off if we get a 404 | ||||||
|  | 
 | ||||||
|  |         Returns: | ||||||
|  |             Deferred: resolves with the http response object on success. | ||||||
|  | 
 | ||||||
|  |             Fails with ``HTTPRequestException``: if we get an HTTP response | ||||||
|  |                 code >= 300. | ||||||
|  |             Fails with ``NotRetryingDestination`` if we are not yet ready | ||||||
|  |                 to retry this server. | ||||||
|         """ |         """ | ||||||
|  |         limiter = yield synapse.util.retryutils.get_retry_limiter( | ||||||
|  |             destination, | ||||||
|  |             self.clock, | ||||||
|  |             self._store, | ||||||
|  |             backoff_on_404=backoff_on_404, | ||||||
|  |             ignore_backoff=ignore_backoff, | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         destination = destination.encode("ascii") | ||||||
|  |         path_bytes = path.encode("ascii") | ||||||
|  |         with limiter: | ||||||
|             headers_dict[b"User-Agent"] = [self.version_string] |             headers_dict[b"User-Agent"] = [self.version_string] | ||||||
|             headers_dict[b"Host"] = [destination] |             headers_dict[b"Host"] = [destination] | ||||||
| 
 | 
 | ||||||
| @ -144,8 +172,7 @@ class MatrixFederationHttpClient(object): | |||||||
| 
 | 
 | ||||||
|                     try: |                     try: | ||||||
|                         def send_request(): |                         def send_request(): | ||||||
|                         request_deferred = preserve_context_over_fn( |                             request_deferred = self.agent.request( | ||||||
|                             self.agent.request, |  | ||||||
|                                 method, |                                 method, | ||||||
|                                 url_bytes, |                                 url_bytes, | ||||||
|                                 Headers(headers_dict), |                                 Headers(headers_dict), | ||||||
| @ -157,7 +184,8 @@ class MatrixFederationHttpClient(object): | |||||||
|                                 time_out=timeout / 1000. if timeout else 60, |                                 time_out=timeout / 1000. if timeout else 60, | ||||||
|                             ) |                             ) | ||||||
| 
 | 
 | ||||||
|                     response = yield preserve_context_over_fn(send_request) |                         with logcontext.PreserveLoggingContext(): | ||||||
|  |                             response = yield send_request() | ||||||
| 
 | 
 | ||||||
|                         log_result = "%d %s" % (response.code, response.phrase,) |                         log_result = "%d %s" % (response.code, response.phrase,) | ||||||
|                         break |                         break | ||||||
| @ -214,7 +242,8 @@ class MatrixFederationHttpClient(object): | |||||||
|             else: |             else: | ||||||
|                 # :'( |                 # :'( | ||||||
|                 # Update transactions table? |                 # Update transactions table? | ||||||
|             body = yield preserve_context_over_fn(readBody, response) |                 with logcontext.PreserveLoggingContext(): | ||||||
|  |                     body = yield readBody(response) | ||||||
|                 raise HttpResponseException( |                 raise HttpResponseException( | ||||||
|                     response.code, response.phrase, body |                     response.code, response.phrase, body | ||||||
|                 ) |                 ) | ||||||
| @ -248,7 +277,9 @@ class MatrixFederationHttpClient(object): | |||||||
| 
 | 
 | ||||||
|     @defer.inlineCallbacks |     @defer.inlineCallbacks | ||||||
|     def put_json(self, destination, path, data={}, json_data_callback=None, |     def put_json(self, destination, path, data={}, json_data_callback=None, | ||||||
|                  long_retries=False, timeout=None): |                  long_retries=False, timeout=None, | ||||||
|  |                  ignore_backoff=False, | ||||||
|  |                  backoff_on_404=False): | ||||||
|         """ Sends the specifed json data using PUT |         """ Sends the specifed json data using PUT | ||||||
| 
 | 
 | ||||||
|         Args: |         Args: | ||||||
| @ -263,11 +294,19 @@ class MatrixFederationHttpClient(object): | |||||||
|                 retry for a short or long time. |                 retry for a short or long time. | ||||||
|             timeout(int): How long to try (in ms) the destination for before |             timeout(int): How long to try (in ms) the destination for before | ||||||
|                 giving up. None indicates no timeout. |                 giving up. None indicates no timeout. | ||||||
|  |             ignore_backoff (bool): true to ignore the historical backoff data | ||||||
|  |                 and try the request anyway. | ||||||
|  |             backoff_on_404 (bool): True if we should count a 404 response as | ||||||
|  |                 a failure of the server (and should therefore back off future | ||||||
|  |                 requests) | ||||||
| 
 | 
 | ||||||
|         Returns: |         Returns: | ||||||
|             Deferred: Succeeds when we get a 2xx HTTP response. The result |             Deferred: Succeeds when we get a 2xx HTTP response. The result | ||||||
|             will be the decoded JSON body. On a 4xx or 5xx error response a |             will be the decoded JSON body. On a 4xx or 5xx error response a | ||||||
|             CodeMessageException is raised. |             CodeMessageException is raised. | ||||||
|  | 
 | ||||||
|  |             Fails with ``NotRetryingDestination`` if we are not yet ready | ||||||
|  |             to retry this server. | ||||||
|         """ |         """ | ||||||
| 
 | 
 | ||||||
|         if not json_data_callback: |         if not json_data_callback: | ||||||
| @ -282,26 +321,29 @@ class MatrixFederationHttpClient(object): | |||||||
|             producer = _JsonProducer(json_data) |             producer = _JsonProducer(json_data) | ||||||
|             return producer |             return producer | ||||||
| 
 | 
 | ||||||
|         response = yield self._create_request( |         response = yield self._request( | ||||||
|             destination.encode("ascii"), |             destination, | ||||||
|             "PUT", |             "PUT", | ||||||
|             path.encode("ascii"), |             path, | ||||||
|             body_callback=body_callback, |             body_callback=body_callback, | ||||||
|             headers_dict={"Content-Type": ["application/json"]}, |             headers_dict={"Content-Type": ["application/json"]}, | ||||||
|             long_retries=long_retries, |             long_retries=long_retries, | ||||||
|             timeout=timeout, |             timeout=timeout, | ||||||
|  |             ignore_backoff=ignore_backoff, | ||||||
|  |             backoff_on_404=backoff_on_404, | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|         if 200 <= response.code < 300: |         if 200 <= response.code < 300: | ||||||
|             # We need to update the transactions table to say it was sent? |             # We need to update the transactions table to say it was sent? | ||||||
|             check_content_type_is_json(response.headers) |             check_content_type_is_json(response.headers) | ||||||
| 
 | 
 | ||||||
|         body = yield preserve_context_over_fn(readBody, response) |         with logcontext.PreserveLoggingContext(): | ||||||
|  |             body = yield readBody(response) | ||||||
|         defer.returnValue(json.loads(body)) |         defer.returnValue(json.loads(body)) | ||||||
| 
 | 
 | ||||||
|     @defer.inlineCallbacks |     @defer.inlineCallbacks | ||||||
|     def post_json(self, destination, path, data={}, long_retries=False, |     def post_json(self, destination, path, data={}, long_retries=False, | ||||||
|                   timeout=None): |                   timeout=None, ignore_backoff=False): | ||||||
|         """ Sends the specifed json data using POST |         """ Sends the specifed json data using POST | ||||||
| 
 | 
 | ||||||
|         Args: |         Args: | ||||||
| @ -314,11 +356,15 @@ class MatrixFederationHttpClient(object): | |||||||
|                 retry for a short or long time. |                 retry for a short or long time. | ||||||
|             timeout(int): How long to try (in ms) the destination for before |             timeout(int): How long to try (in ms) the destination for before | ||||||
|                 giving up. None indicates no timeout. |                 giving up. None indicates no timeout. | ||||||
| 
 |             ignore_backoff (bool): true to ignore the historical backoff data and | ||||||
|  |                 try the request anyway. | ||||||
|         Returns: |         Returns: | ||||||
|             Deferred: Succeeds when we get a 2xx HTTP response. The result |             Deferred: Succeeds when we get a 2xx HTTP response. The result | ||||||
|             will be the decoded JSON body. On a 4xx or 5xx error response a |             will be the decoded JSON body. On a 4xx or 5xx error response a | ||||||
|             CodeMessageException is raised. |             CodeMessageException is raised. | ||||||
|  | 
 | ||||||
|  |             Fails with ``NotRetryingDestination`` if we are not yet ready | ||||||
|  |             to retry this server. | ||||||
|         """ |         """ | ||||||
| 
 | 
 | ||||||
|         def body_callback(method, url_bytes, headers_dict): |         def body_callback(method, url_bytes, headers_dict): | ||||||
| @ -327,27 +373,29 @@ class MatrixFederationHttpClient(object): | |||||||
|             ) |             ) | ||||||
|             return _JsonProducer(data) |             return _JsonProducer(data) | ||||||
| 
 | 
 | ||||||
|         response = yield self._create_request( |         response = yield self._request( | ||||||
|             destination.encode("ascii"), |             destination, | ||||||
|             "POST", |             "POST", | ||||||
|             path.encode("ascii"), |             path, | ||||||
|             body_callback=body_callback, |             body_callback=body_callback, | ||||||
|             headers_dict={"Content-Type": ["application/json"]}, |             headers_dict={"Content-Type": ["application/json"]}, | ||||||
|             long_retries=long_retries, |             long_retries=long_retries, | ||||||
|             timeout=timeout, |             timeout=timeout, | ||||||
|  |             ignore_backoff=ignore_backoff, | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|         if 200 <= response.code < 300: |         if 200 <= response.code < 300: | ||||||
|             # We need to update the transactions table to say it was sent? |             # We need to update the transactions table to say it was sent? | ||||||
|             check_content_type_is_json(response.headers) |             check_content_type_is_json(response.headers) | ||||||
| 
 | 
 | ||||||
|         body = yield preserve_context_over_fn(readBody, response) |         with logcontext.PreserveLoggingContext(): | ||||||
|  |             body = yield readBody(response) | ||||||
| 
 | 
 | ||||||
|         defer.returnValue(json.loads(body)) |         defer.returnValue(json.loads(body)) | ||||||
| 
 | 
 | ||||||
|     @defer.inlineCallbacks |     @defer.inlineCallbacks | ||||||
|     def get_json(self, destination, path, args={}, retry_on_dns_fail=True, |     def get_json(self, destination, path, args={}, retry_on_dns_fail=True, | ||||||
|                  timeout=None): |                  timeout=None, ignore_backoff=False): | ||||||
|         """ GETs some json from the given host homeserver and path |         """ GETs some json from the given host homeserver and path | ||||||
| 
 | 
 | ||||||
|         Args: |         Args: | ||||||
| @ -359,11 +407,16 @@ class MatrixFederationHttpClient(object): | |||||||
|             timeout (int): How long to try (in ms) the destination for before |             timeout (int): How long to try (in ms) the destination for before | ||||||
|                 giving up. None indicates no timeout and that the request will |                 giving up. None indicates no timeout and that the request will | ||||||
|                 be retried. |                 be retried. | ||||||
|  |             ignore_backoff (bool): true to ignore the historical backoff data | ||||||
|  |                 and try the request anyway. | ||||||
|         Returns: |         Returns: | ||||||
|             Deferred: Succeeds when we get *any* HTTP response. |             Deferred: Succeeds when we get *any* HTTP response. | ||||||
| 
 | 
 | ||||||
|             The result of the deferred is a tuple of `(code, response)`, |             The result of the deferred is a tuple of `(code, response)`, | ||||||
|             where `response` is a dict representing the decoded JSON body. |             where `response` is a dict representing the decoded JSON body. | ||||||
|  | 
 | ||||||
|  |             Fails with ``NotRetryingDestination`` if we are not yet ready | ||||||
|  |             to retry this server. | ||||||
|         """ |         """ | ||||||
|         logger.debug("get_json args: %s", args) |         logger.debug("get_json args: %s", args) | ||||||
| 
 | 
 | ||||||
| @ -380,36 +433,47 @@ class MatrixFederationHttpClient(object): | |||||||
|             self.sign_request(destination, method, url_bytes, headers_dict) |             self.sign_request(destination, method, url_bytes, headers_dict) | ||||||
|             return None |             return None | ||||||
| 
 | 
 | ||||||
|         response = yield self._create_request( |         response = yield self._request( | ||||||
|             destination.encode("ascii"), |             destination, | ||||||
|             "GET", |             "GET", | ||||||
|             path.encode("ascii"), |             path, | ||||||
|             query_bytes=query_bytes, |             query_bytes=query_bytes, | ||||||
|             body_callback=body_callback, |             body_callback=body_callback, | ||||||
|             retry_on_dns_fail=retry_on_dns_fail, |             retry_on_dns_fail=retry_on_dns_fail, | ||||||
|             timeout=timeout, |             timeout=timeout, | ||||||
|  |             ignore_backoff=ignore_backoff, | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|         if 200 <= response.code < 300: |         if 200 <= response.code < 300: | ||||||
|             # We need to update the transactions table to say it was sent? |             # We need to update the transactions table to say it was sent? | ||||||
|             check_content_type_is_json(response.headers) |             check_content_type_is_json(response.headers) | ||||||
| 
 | 
 | ||||||
|         body = yield preserve_context_over_fn(readBody, response) |         with logcontext.PreserveLoggingContext(): | ||||||
|  |             body = yield readBody(response) | ||||||
| 
 | 
 | ||||||
|         defer.returnValue(json.loads(body)) |         defer.returnValue(json.loads(body)) | ||||||
| 
 | 
 | ||||||
|     @defer.inlineCallbacks |     @defer.inlineCallbacks | ||||||
|     def get_file(self, destination, path, output_stream, args={}, |     def get_file(self, destination, path, output_stream, args={}, | ||||||
|                  retry_on_dns_fail=True, max_size=None): |                  retry_on_dns_fail=True, max_size=None, | ||||||
|  |                  ignore_backoff=False): | ||||||
|         """GETs a file from a given homeserver |         """GETs a file from a given homeserver | ||||||
|         Args: |         Args: | ||||||
|             destination (str): The remote server to send the HTTP request to. |             destination (str): The remote server to send the HTTP request to. | ||||||
|             path (str): The HTTP path to GET. |             path (str): The HTTP path to GET. | ||||||
|             output_stream (file): File to write the response body to. |             output_stream (file): File to write the response body to. | ||||||
|             args (dict): Optional dictionary used to create the query string. |             args (dict): Optional dictionary used to create the query string. | ||||||
|  |             ignore_backoff (bool): true to ignore the historical backoff data | ||||||
|  |                 and try the request anyway. | ||||||
|         Returns: |         Returns: | ||||||
|             A (int,dict) tuple of the file length and a dict of the response |             Deferred: resolves with an (int,dict) tuple of the file length and | ||||||
|             headers. |             a dict of the response headers. | ||||||
|  | 
 | ||||||
|  |             Fails with ``HTTPRequestException`` if we get an HTTP response code | ||||||
|  |             >= 300 | ||||||
|  | 
 | ||||||
|  |             Fails with ``NotRetryingDestination`` if we are not yet ready | ||||||
|  |             to retry this server. | ||||||
|         """ |         """ | ||||||
| 
 | 
 | ||||||
|         encoded_args = {} |         encoded_args = {} | ||||||
| @ -419,26 +483,27 @@ class MatrixFederationHttpClient(object): | |||||||
|             encoded_args[k] = [v.encode("UTF-8") for v in vs] |             encoded_args[k] = [v.encode("UTF-8") for v in vs] | ||||||
| 
 | 
 | ||||||
|         query_bytes = urllib.urlencode(encoded_args, True) |         query_bytes = urllib.urlencode(encoded_args, True) | ||||||
|         logger.debug("Query bytes: %s Retry DNS: %s", args, retry_on_dns_fail) |         logger.debug("Query bytes: %s Retry DNS: %s", query_bytes, retry_on_dns_fail) | ||||||
| 
 | 
 | ||||||
|         def body_callback(method, url_bytes, headers_dict): |         def body_callback(method, url_bytes, headers_dict): | ||||||
|             self.sign_request(destination, method, url_bytes, headers_dict) |             self.sign_request(destination, method, url_bytes, headers_dict) | ||||||
|             return None |             return None | ||||||
| 
 | 
 | ||||||
|         response = yield self._create_request( |         response = yield self._request( | ||||||
|             destination.encode("ascii"), |             destination, | ||||||
|             "GET", |             "GET", | ||||||
|             path.encode("ascii"), |             path, | ||||||
|             query_bytes=query_bytes, |             query_bytes=query_bytes, | ||||||
|             body_callback=body_callback, |             body_callback=body_callback, | ||||||
|             retry_on_dns_fail=retry_on_dns_fail |             retry_on_dns_fail=retry_on_dns_fail, | ||||||
|  |             ignore_backoff=ignore_backoff, | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|         headers = dict(response.headers.getAllRawHeaders()) |         headers = dict(response.headers.getAllRawHeaders()) | ||||||
| 
 | 
 | ||||||
|         try: |         try: | ||||||
|             length = yield preserve_context_over_fn( |             with logcontext.PreserveLoggingContext(): | ||||||
|                 _readBodyToFile, |                 length = yield _readBodyToFile( | ||||||
|                     response, output_stream, max_size |                     response, output_stream, max_size | ||||||
|                 ) |                 ) | ||||||
|         except: |         except: | ||||||
|  | |||||||
| @ -192,6 +192,16 @@ def parse_json_object_from_request(request): | |||||||
|     return content |     return content | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | def assert_params_in_request(body, required): | ||||||
|  |     absent = [] | ||||||
|  |     for k in required: | ||||||
|  |         if k not in body: | ||||||
|  |             absent.append(k) | ||||||
|  | 
 | ||||||
|  |     if len(absent) > 0: | ||||||
|  |         raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| class RestServlet(object): | class RestServlet(object): | ||||||
| 
 | 
 | ||||||
|     """ A Synapse REST Servlet. |     """ A Synapse REST Servlet. | ||||||
|  | |||||||
| @ -16,6 +16,7 @@ | |||||||
| from twisted.internet import defer | from twisted.internet import defer | ||||||
| from synapse.api.constants import EventTypes, Membership | from synapse.api.constants import EventTypes, Membership | ||||||
| from synapse.api.errors import AuthError | from synapse.api.errors import AuthError | ||||||
|  | from synapse.handlers.presence import format_user_presence_state | ||||||
| 
 | 
 | ||||||
| from synapse.util import DeferredTimedOutError | from synapse.util import DeferredTimedOutError | ||||||
| from synapse.util.logutils import log_function | from synapse.util.logutils import log_function | ||||||
| @ -37,6 +38,10 @@ metrics = synapse.metrics.get_metrics_for(__name__) | |||||||
| 
 | 
 | ||||||
| notified_events_counter = metrics.register_counter("notified_events") | notified_events_counter = metrics.register_counter("notified_events") | ||||||
| 
 | 
 | ||||||
|  | users_woken_by_stream_counter = metrics.register_counter( | ||||||
|  |     "users_woken_by_stream", labels=["stream"] | ||||||
|  | ) | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
| # TODO(paul): Should be shared somewhere | # TODO(paul): Should be shared somewhere | ||||||
| def count(func, l): | def count(func, l): | ||||||
| @ -73,6 +78,13 @@ class _NotifierUserStream(object): | |||||||
|         self.user_id = user_id |         self.user_id = user_id | ||||||
|         self.rooms = set(rooms) |         self.rooms = set(rooms) | ||||||
|         self.current_token = current_token |         self.current_token = current_token | ||||||
|  | 
 | ||||||
|  |         # The last token for which we should wake up any streams that have a | ||||||
|  |         # token that comes before it. This gets updated everytime we get poked. | ||||||
|  |         # We start it at the current token since if we get any streams | ||||||
|  |         # that have a token from before we have no idea whether they should be | ||||||
|  |         # woken up or not, so lets just wake them up. | ||||||
|  |         self.last_notified_token = current_token | ||||||
|         self.last_notified_ms = time_now_ms |         self.last_notified_ms = time_now_ms | ||||||
| 
 | 
 | ||||||
|         with PreserveLoggingContext(): |         with PreserveLoggingContext(): | ||||||
| @ -89,9 +101,12 @@ class _NotifierUserStream(object): | |||||||
|         self.current_token = self.current_token.copy_and_advance( |         self.current_token = self.current_token.copy_and_advance( | ||||||
|             stream_key, stream_id |             stream_key, stream_id | ||||||
|         ) |         ) | ||||||
|  |         self.last_notified_token = self.current_token | ||||||
|         self.last_notified_ms = time_now_ms |         self.last_notified_ms = time_now_ms | ||||||
|         noify_deferred = self.notify_deferred |         noify_deferred = self.notify_deferred | ||||||
| 
 | 
 | ||||||
|  |         users_woken_by_stream_counter.inc(stream_key) | ||||||
|  | 
 | ||||||
|         with PreserveLoggingContext(): |         with PreserveLoggingContext(): | ||||||
|             self.notify_deferred = ObservableDeferred(defer.Deferred()) |             self.notify_deferred = ObservableDeferred(defer.Deferred()) | ||||||
|             noify_deferred.callback(self.current_token) |             noify_deferred.callback(self.current_token) | ||||||
| @ -113,8 +128,14 @@ class _NotifierUserStream(object): | |||||||
|     def new_listener(self, token): |     def new_listener(self, token): | ||||||
|         """Returns a deferred that is resolved when there is a new token |         """Returns a deferred that is resolved when there is a new token | ||||||
|         greater than the given token. |         greater than the given token. | ||||||
|  | 
 | ||||||
|  |         Args: | ||||||
|  |             token: The token from which we are streaming from, i.e. we shouldn't | ||||||
|  |                 notify for things that happened before this. | ||||||
|         """ |         """ | ||||||
|         if self.current_token.is_after(token): |         # Immediately wake up stream if something has already since happened | ||||||
|  |         # since their last token. | ||||||
|  |         if self.last_notified_token.is_after(token): | ||||||
|             return _NotificationListener(defer.succeed(self.current_token)) |             return _NotificationListener(defer.succeed(self.current_token)) | ||||||
|         else: |         else: | ||||||
|             return _NotificationListener(self.notify_deferred.observe()) |             return _NotificationListener(self.notify_deferred.observe()) | ||||||
| @ -283,8 +304,7 @@ class Notifier(object): | |||||||
|         if user_stream is None: |         if user_stream is None: | ||||||
|             current_token = yield self.event_sources.get_current_token() |             current_token = yield self.event_sources.get_current_token() | ||||||
|             if room_ids is None: |             if room_ids is None: | ||||||
|                 rooms = yield self.store.get_rooms_for_user(user_id) |                 room_ids = yield self.store.get_rooms_for_user(user_id) | ||||||
|                 room_ids = [room.room_id for room in rooms] |  | ||||||
|             user_stream = _NotifierUserStream( |             user_stream = _NotifierUserStream( | ||||||
|                 user_id=user_id, |                 user_id=user_id, | ||||||
|                 rooms=room_ids, |                 rooms=room_ids, | ||||||
| @ -294,40 +314,44 @@ class Notifier(object): | |||||||
|             self._register_with_keys(user_stream) |             self._register_with_keys(user_stream) | ||||||
| 
 | 
 | ||||||
|         result = None |         result = None | ||||||
|  |         prev_token = from_token | ||||||
|         if timeout: |         if timeout: | ||||||
|             end_time = self.clock.time_msec() + timeout |             end_time = self.clock.time_msec() + timeout | ||||||
| 
 | 
 | ||||||
|             prev_token = from_token |  | ||||||
|             while not result: |             while not result: | ||||||
|                 try: |                 try: | ||||||
|                     current_token = user_stream.current_token |  | ||||||
| 
 |  | ||||||
|                     result = yield callback(prev_token, current_token) |  | ||||||
|                     if result: |  | ||||||
|                         break |  | ||||||
| 
 |  | ||||||
|                     now = self.clock.time_msec() |                     now = self.clock.time_msec() | ||||||
|                     if end_time <= now: |                     if end_time <= now: | ||||||
|                         break |                         break | ||||||
| 
 | 
 | ||||||
|                     # Now we wait for the _NotifierUserStream to be told there |                     # Now we wait for the _NotifierUserStream to be told there | ||||||
|                     # is a new token. |                     # is a new token. | ||||||
|                     # We need to supply the token we supplied to callback so |  | ||||||
|                     # that we don't miss any current_token updates. |  | ||||||
|                     prev_token = current_token |  | ||||||
|                     listener = user_stream.new_listener(prev_token) |                     listener = user_stream.new_listener(prev_token) | ||||||
|                     with PreserveLoggingContext(): |                     with PreserveLoggingContext(): | ||||||
|                         yield self.clock.time_bound_deferred( |                         yield self.clock.time_bound_deferred( | ||||||
|                             listener.deferred, |                             listener.deferred, | ||||||
|                             time_out=(end_time - now) / 1000. |                             time_out=(end_time - now) / 1000. | ||||||
|                         ) |                         ) | ||||||
|  | 
 | ||||||
|  |                     current_token = user_stream.current_token | ||||||
|  | 
 | ||||||
|  |                     result = yield callback(prev_token, current_token) | ||||||
|  |                     if result: | ||||||
|  |                         break | ||||||
|  | 
 | ||||||
|  |                     # Update the prev_token to the current_token since nothing | ||||||
|  |                     # has happened between the old prev_token and the current_token | ||||||
|  |                     prev_token = current_token | ||||||
|                 except DeferredTimedOutError: |                 except DeferredTimedOutError: | ||||||
|                     break |                     break | ||||||
|                 except defer.CancelledError: |                 except defer.CancelledError: | ||||||
|                     break |                     break | ||||||
|         else: | 
 | ||||||
|  |         if result is None: | ||||||
|  |             # This happened if there was no timeout or if the timeout had | ||||||
|  |             # already expired. | ||||||
|             current_token = user_stream.current_token |             current_token = user_stream.current_token | ||||||
|             result = yield callback(from_token, current_token) |             result = yield callback(prev_token, current_token) | ||||||
| 
 | 
 | ||||||
|         defer.returnValue(result) |         defer.returnValue(result) | ||||||
| 
 | 
 | ||||||
| @ -388,6 +412,15 @@ class Notifier(object): | |||||||
|                         new_events, |                         new_events, | ||||||
|                         is_peeking=is_peeking, |                         is_peeking=is_peeking, | ||||||
|                     ) |                     ) | ||||||
|  |                 elif name == "presence": | ||||||
|  |                     now = self.clock.time_msec() | ||||||
|  |                     new_events[:] = [ | ||||||
|  |                         { | ||||||
|  |                             "type": "m.presence", | ||||||
|  |                             "content": format_user_presence_state(event, now), | ||||||
|  |                         } | ||||||
|  |                         for event in new_events | ||||||
|  |                     ] | ||||||
| 
 | 
 | ||||||
|                 events.extend(new_events) |                 events.extend(new_events) | ||||||
|                 end_token = end_token.copy_and_replace(keyname, new_key) |                 end_token = end_token.copy_and_replace(keyname, new_key) | ||||||
| @ -420,8 +453,7 @@ class Notifier(object): | |||||||
| 
 | 
 | ||||||
|     @defer.inlineCallbacks |     @defer.inlineCallbacks | ||||||
|     def _get_room_ids(self, user, explicit_room_id): |     def _get_room_ids(self, user, explicit_room_id): | ||||||
|         joined_rooms = yield self.store.get_rooms_for_user(user.to_string()) |         joined_room_ids = yield self.store.get_rooms_for_user(user.to_string()) | ||||||
|         joined_room_ids = map(lambda r: r.room_id, joined_rooms) |  | ||||||
|         if explicit_room_id: |         if explicit_room_id: | ||||||
|             if explicit_room_id in joined_room_ids: |             if explicit_room_id in joined_room_ids: | ||||||
|                 defer.returnValue(([explicit_room_id], True)) |                 defer.returnValue(([explicit_room_id], True)) | ||||||
|  | |||||||
| @ -139,7 +139,7 @@ class Mailer(object): | |||||||
| 
 | 
 | ||||||
|         @defer.inlineCallbacks |         @defer.inlineCallbacks | ||||||
|         def _fetch_room_state(room_id): |         def _fetch_room_state(room_id): | ||||||
|             room_state = yield self.state_handler.get_current_state_ids(room_id) |             room_state = yield self.store.get_current_state_ids(room_id) | ||||||
|             state_by_room[room_id] = room_state |             state_by_room[room_id] = room_state | ||||||
| 
 | 
 | ||||||
|         # Run at most 3 of these at once: sync does 10 at a time but email |         # Run at most 3 of these at once: sync does 10 at a time but email | ||||||
|  | |||||||
| @ -17,6 +17,7 @@ import logging | |||||||
| import re | import re | ||||||
| 
 | 
 | ||||||
| from synapse.types import UserID | from synapse.types import UserID | ||||||
|  | from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache | ||||||
| from synapse.util.caches.lrucache import LruCache | from synapse.util.caches.lrucache import LruCache | ||||||
| 
 | 
 | ||||||
| logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||||
| @ -125,6 +126,11 @@ class PushRuleEvaluatorForEvent(object): | |||||||
|         return self._value_cache.get(dotted_key, None) |         return self._value_cache.get(dotted_key, None) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | # Caches (glob, word_boundary) -> regex for push. See _glob_matches | ||||||
|  | regex_cache = LruCache(50000 * CACHE_SIZE_FACTOR) | ||||||
|  | register_cache("regex_push_cache", regex_cache) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| def _glob_matches(glob, value, word_boundary=False): | def _glob_matches(glob, value, word_boundary=False): | ||||||
|     """Tests if value matches glob. |     """Tests if value matches glob. | ||||||
| 
 | 
 | ||||||
| @ -137,7 +143,29 @@ def _glob_matches(glob, value, word_boundary=False): | |||||||
|     Returns: |     Returns: | ||||||
|         bool |         bool | ||||||
|     """ |     """ | ||||||
|  | 
 | ||||||
|     try: |     try: | ||||||
|  |         r = regex_cache.get((glob, word_boundary), None) | ||||||
|  |         if not r: | ||||||
|  |             r = _glob_to_re(glob, word_boundary) | ||||||
|  |             regex_cache[(glob, word_boundary)] = r | ||||||
|  |         return r.search(value) | ||||||
|  |     except re.error: | ||||||
|  |         logger.warn("Failed to parse glob to regex: %r", glob) | ||||||
|  |         return False | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def _glob_to_re(glob, word_boundary): | ||||||
|  |     """Generates regex for a given glob. | ||||||
|  | 
 | ||||||
|  |     Args: | ||||||
|  |         glob (string) | ||||||
|  |         word_boundary (bool): Whether to match against word boundaries or entire | ||||||
|  |             string. Defaults to False. | ||||||
|  | 
 | ||||||
|  |     Returns: | ||||||
|  |         regex object | ||||||
|  |     """ | ||||||
|     if IS_GLOB.search(glob): |     if IS_GLOB.search(glob): | ||||||
|         r = re.escape(glob) |         r = re.escape(glob) | ||||||
| 
 | 
 | ||||||
| @ -156,25 +184,20 @@ def _glob_matches(glob, value, word_boundary=False): | |||||||
|         ) |         ) | ||||||
|         if word_boundary: |         if word_boundary: | ||||||
|             r = r"\b%s\b" % (r,) |             r = r"\b%s\b" % (r,) | ||||||
|                 r = _compile_regex(r) |  | ||||||
| 
 | 
 | ||||||
|                 return r.search(value) |             return re.compile(r, flags=re.IGNORECASE) | ||||||
|         else: |         else: | ||||||
|                 r = r + "$" |             r = "^" + r + "$" | ||||||
|                 r = _compile_regex(r) |  | ||||||
| 
 | 
 | ||||||
|                 return r.match(value) |             return re.compile(r, flags=re.IGNORECASE) | ||||||
|     elif word_boundary: |     elif word_boundary: | ||||||
|         r = re.escape(glob) |         r = re.escape(glob) | ||||||
|         r = r"\b%s\b" % (r,) |         r = r"\b%s\b" % (r,) | ||||||
|             r = _compile_regex(r) |  | ||||||
| 
 | 
 | ||||||
|             return r.search(value) |         return re.compile(r, flags=re.IGNORECASE) | ||||||
|     else: |     else: | ||||||
|             return value.lower() == glob.lower() |         r = "^" + re.escape(glob) + "$" | ||||||
|     except re.error: |         return re.compile(r, flags=re.IGNORECASE) | ||||||
|         logger.warn("Failed to parse glob to regex: %r", glob) |  | ||||||
|         return False |  | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def _flatten_dict(d, prefix=[], result={}): | def _flatten_dict(d, prefix=[], result={}): | ||||||
| @ -185,16 +208,3 @@ def _flatten_dict(d, prefix=[], result={}): | |||||||
|             _flatten_dict(value, prefix=(prefix + [key]), result=result) |             _flatten_dict(value, prefix=(prefix + [key]), result=result) | ||||||
| 
 | 
 | ||||||
|     return result |     return result | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| regex_cache = LruCache(5000) |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| def _compile_regex(regex_str): |  | ||||||
|     r = regex_cache.get(regex_str, None) |  | ||||||
|     if r: |  | ||||||
|         return r |  | ||||||
| 
 |  | ||||||
|     r = re.compile(regex_str, flags=re.IGNORECASE) |  | ||||||
|     regex_cache[regex_str] = r |  | ||||||
|     return r |  | ||||||
|  | |||||||
| @ -33,13 +33,13 @@ def get_badge_count(store, user_id): | |||||||
| 
 | 
 | ||||||
|     badge = len(invites) |     badge = len(invites) | ||||||
| 
 | 
 | ||||||
|     for r in joins: |     for room_id in joins: | ||||||
|         if r.room_id in my_receipts_by_room: |         if room_id in my_receipts_by_room: | ||||||
|             last_unread_event_id = my_receipts_by_room[r.room_id] |             last_unread_event_id = my_receipts_by_room[room_id] | ||||||
| 
 | 
 | ||||||
|             notifs = yield ( |             notifs = yield ( | ||||||
|                 store.get_unread_event_push_actions_by_room_for_user( |                 store.get_unread_event_push_actions_by_room_for_user( | ||||||
|                     r.room_id, user_id, last_unread_event_id |                     room_id, user_id, last_unread_event_id | ||||||
|                 ) |                 ) | ||||||
|             ) |             ) | ||||||
|             # return one badge count per conversation, as count per |             # return one badge count per conversation, as count per | ||||||
|  | |||||||
| @ -1,4 +1,5 @@ | |||||||
| # Copyright 2015, 2016 OpenMarket Ltd | # Copyright 2015, 2016 OpenMarket Ltd | ||||||
|  | # Copyright 2017 Vector Creations Ltd | ||||||
| # | # | ||||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||||
| @ -18,6 +19,7 @@ from distutils.version import LooseVersion | |||||||
| logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||||
| 
 | 
 | ||||||
| REQUIREMENTS = { | REQUIREMENTS = { | ||||||
|  |     "jsonschema>=2.5.1": ["jsonschema>=2.5.1"], | ||||||
|     "frozendict>=0.4": ["frozendict"], |     "frozendict>=0.4": ["frozendict"], | ||||||
|     "unpaddedbase64>=1.1.0": ["unpaddedbase64>=1.1.0"], |     "unpaddedbase64>=1.1.0": ["unpaddedbase64>=1.1.0"], | ||||||
|     "canonicaljson>=1.0.0": ["canonicaljson>=1.0.0"], |     "canonicaljson>=1.0.0": ["canonicaljson>=1.0.0"], | ||||||
| @ -37,6 +39,7 @@ REQUIREMENTS = { | |||||||
|     "pysaml2>=3.0.0,<4.0.0": ["saml2>=3.0.0,<4.0.0"], |     "pysaml2>=3.0.0,<4.0.0": ["saml2>=3.0.0,<4.0.0"], | ||||||
|     "pymacaroons-pynacl": ["pymacaroons"], |     "pymacaroons-pynacl": ["pymacaroons"], | ||||||
|     "msgpack-python>=0.3.0": ["msgpack"], |     "msgpack-python>=0.3.0": ["msgpack"], | ||||||
|  |     "phonenumbers>=8.2.0": ["phonenumbers"], | ||||||
| } | } | ||||||
| CONDITIONAL_REQUIREMENTS = { | CONDITIONAL_REQUIREMENTS = { | ||||||
|     "web_client": { |     "web_client": { | ||||||
|  | |||||||
| @ -283,12 +283,12 @@ class ReplicationResource(Resource): | |||||||
| 
 | 
 | ||||||
|             if request_events != upto_events_token: |             if request_events != upto_events_token: | ||||||
|                 writer.write_header_and_rows("events", res.new_forward_events, ( |                 writer.write_header_and_rows("events", res.new_forward_events, ( | ||||||
|                     "position", "internal", "json", "state_group" |                     "position", "event_id", "room_id", "type", "state_key", | ||||||
|                 ), position=upto_events_token) |                 ), position=upto_events_token) | ||||||
| 
 | 
 | ||||||
|             if request_backfill != upto_backfill_token: |             if request_backfill != upto_backfill_token: | ||||||
|                 writer.write_header_and_rows("backfill", res.new_backfill_events, ( |                 writer.write_header_and_rows("backfill", res.new_backfill_events, ( | ||||||
|                     "position", "internal", "json", "state_group", |                     "position", "event_id", "room_id", "type", "state_key", "redacts", | ||||||
|                 ), position=upto_backfill_token) |                 ), position=upto_backfill_token) | ||||||
| 
 | 
 | ||||||
|             writer.write_header_and_rows( |             writer.write_header_and_rows( | ||||||
|  | |||||||
| @ -27,4 +27,9 @@ class SlavedIdTracker(object): | |||||||
|         self._current = (max if self.step > 0 else min)(self._current, new_id) |         self._current = (max if self.step > 0 else min)(self._current, new_id) | ||||||
| 
 | 
 | ||||||
|     def get_current_token(self): |     def get_current_token(self): | ||||||
|  |         """ | ||||||
|  | 
 | ||||||
|  |         Returns: | ||||||
|  |             int | ||||||
|  |         """ | ||||||
|         return self._current |         return self._current | ||||||
|  | |||||||
| @ -16,7 +16,6 @@ from ._base import BaseSlavedStore | |||||||
| from ._slaved_id_tracker import SlavedIdTracker | from ._slaved_id_tracker import SlavedIdTracker | ||||||
| 
 | 
 | ||||||
| from synapse.api.constants import EventTypes | from synapse.api.constants import EventTypes | ||||||
| from synapse.events import FrozenEvent |  | ||||||
| from synapse.storage import DataStore | from synapse.storage import DataStore | ||||||
| from synapse.storage.roommember import RoomMemberStore | from synapse.storage.roommember import RoomMemberStore | ||||||
| from synapse.storage.event_federation import EventFederationStore | from synapse.storage.event_federation import EventFederationStore | ||||||
| @ -25,7 +24,6 @@ from synapse.storage.state import StateStore | |||||||
| from synapse.storage.stream import StreamStore | from synapse.storage.stream import StreamStore | ||||||
| from synapse.util.caches.stream_change_cache import StreamChangeCache | from synapse.util.caches.stream_change_cache import StreamChangeCache | ||||||
| 
 | 
 | ||||||
| import ujson as json |  | ||||||
| import logging | import logging | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @ -109,6 +107,10 @@ class SlavedEventStore(BaseSlavedStore): | |||||||
|     get_recent_event_ids_for_room = ( |     get_recent_event_ids_for_room = ( | ||||||
|         StreamStore.__dict__["get_recent_event_ids_for_room"] |         StreamStore.__dict__["get_recent_event_ids_for_room"] | ||||||
|     ) |     ) | ||||||
|  |     get_current_state_ids = ( | ||||||
|  |         StateStore.__dict__["get_current_state_ids"] | ||||||
|  |     ) | ||||||
|  |     has_room_changed_since = DataStore.has_room_changed_since.__func__ | ||||||
| 
 | 
 | ||||||
|     get_unread_push_actions_for_user_in_range_for_http = ( |     get_unread_push_actions_for_user_in_range_for_http = ( | ||||||
|         DataStore.get_unread_push_actions_for_user_in_range_for_http.__func__ |         DataStore.get_unread_push_actions_for_user_in_range_for_http.__func__ | ||||||
| @ -165,7 +167,6 @@ class SlavedEventStore(BaseSlavedStore): | |||||||
|     _get_rooms_for_user_where_membership_is_txn = ( |     _get_rooms_for_user_where_membership_is_txn = ( | ||||||
|         DataStore._get_rooms_for_user_where_membership_is_txn.__func__ |         DataStore._get_rooms_for_user_where_membership_is_txn.__func__ | ||||||
|     ) |     ) | ||||||
|     _get_members_rows_txn = DataStore._get_members_rows_txn.__func__ |  | ||||||
|     _get_state_for_groups = DataStore._get_state_for_groups.__func__ |     _get_state_for_groups = DataStore._get_state_for_groups.__func__ | ||||||
|     _get_all_state_from_cache = DataStore._get_all_state_from_cache.__func__ |     _get_all_state_from_cache = DataStore._get_all_state_from_cache.__func__ | ||||||
|     _get_events_around_txn = DataStore._get_events_around_txn.__func__ |     _get_events_around_txn = DataStore._get_events_around_txn.__func__ | ||||||
| @ -238,46 +239,32 @@ class SlavedEventStore(BaseSlavedStore): | |||||||
|         return super(SlavedEventStore, self).process_replication(result) |         return super(SlavedEventStore, self).process_replication(result) | ||||||
| 
 | 
 | ||||||
|     def _process_replication_row(self, row, backfilled): |     def _process_replication_row(self, row, backfilled): | ||||||
|         internal = json.loads(row[1]) |         stream_ordering = row[0] if not backfilled else -row[0] | ||||||
|         event_json = json.loads(row[2]) |  | ||||||
|         event = FrozenEvent(event_json, internal_metadata_dict=internal) |  | ||||||
|         self.invalidate_caches_for_event( |         self.invalidate_caches_for_event( | ||||||
|             event, backfilled, |             stream_ordering, row[1], row[2], row[3], row[4], row[5], | ||||||
|  |             backfilled=backfilled, | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|     def invalidate_caches_for_event(self, event, backfilled): |     def invalidate_caches_for_event(self, stream_ordering, event_id, room_id, | ||||||
|         self._invalidate_get_event_cache(event.event_id) |                                     etype, state_key, redacts, backfilled): | ||||||
|  |         self._invalidate_get_event_cache(event_id) | ||||||
| 
 | 
 | ||||||
|         self.get_latest_event_ids_in_room.invalidate((event.room_id,)) |         self.get_latest_event_ids_in_room.invalidate((room_id,)) | ||||||
| 
 | 
 | ||||||
|         self.get_unread_event_push_actions_by_room_for_user.invalidate_many( |         self.get_unread_event_push_actions_by_room_for_user.invalidate_many( | ||||||
|             (event.room_id,) |             (room_id,) | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|         if not backfilled: |         if not backfilled: | ||||||
|             self._events_stream_cache.entity_has_changed( |             self._events_stream_cache.entity_has_changed( | ||||||
|                 event.room_id, event.internal_metadata.stream_ordering |                 room_id, stream_ordering | ||||||
|             ) |             ) | ||||||
| 
 | 
 | ||||||
|         # self.get_unread_event_push_actions_by_room_for_user.invalidate_many( |         if redacts: | ||||||
|         #     (event.room_id,) |             self._invalidate_get_event_cache(redacts) | ||||||
|         # ) |  | ||||||
| 
 | 
 | ||||||
|         if event.type == EventTypes.Redaction: |         if etype == EventTypes.Member: | ||||||
|             self._invalidate_get_event_cache(event.redacts) |  | ||||||
| 
 |  | ||||||
|         if event.type == EventTypes.Member: |  | ||||||
|             self._membership_stream_cache.entity_has_changed( |             self._membership_stream_cache.entity_has_changed( | ||||||
|                 event.state_key, event.internal_metadata.stream_ordering |                 state_key, stream_ordering | ||||||
|             ) |             ) | ||||||
|             self.get_invited_rooms_for_user.invalidate((event.state_key,)) |             self.get_invited_rooms_for_user.invalidate((state_key,)) | ||||||
| 
 |  | ||||||
|         if not event.is_state(): |  | ||||||
|             return |  | ||||||
| 
 |  | ||||||
|         if backfilled: |  | ||||||
|             return |  | ||||||
| 
 |  | ||||||
|         if (not event.internal_metadata.is_invite_from_remote() |  | ||||||
|                 and event.internal_metadata.is_outlier()): |  | ||||||
|             return |  | ||||||
|  | |||||||
| @ -57,5 +57,6 @@ class SlavedPresenceStore(BaseSlavedStore): | |||||||
|                 self.presence_stream_cache.entity_has_changed( |                 self.presence_stream_cache.entity_has_changed( | ||||||
|                     user_id, position |                     user_id, position | ||||||
|                 ) |                 ) | ||||||
|  |                 self._get_presence_for_user.invalidate((user_id,)) | ||||||
| 
 | 
 | ||||||
|         return super(SlavedPresenceStore, self).process_replication(result) |         return super(SlavedPresenceStore, self).process_replication(result) | ||||||
|  | |||||||
| @ -19,6 +19,7 @@ from synapse.api.errors import SynapseError, LoginError, Codes | |||||||
| from synapse.types import UserID | from synapse.types import UserID | ||||||
| from synapse.http.server import finish_request | from synapse.http.server import finish_request | ||||||
| from synapse.http.servlet import parse_json_object_from_request | from synapse.http.servlet import parse_json_object_from_request | ||||||
|  | from synapse.util.msisdn import phone_number_to_msisdn | ||||||
| 
 | 
 | ||||||
| from .base import ClientV1RestServlet, client_path_patterns | from .base import ClientV1RestServlet, client_path_patterns | ||||||
| 
 | 
 | ||||||
| @ -33,10 +34,55 @@ from saml2.client import Saml2Client | |||||||
| 
 | 
 | ||||||
| import xml.etree.ElementTree as ET | import xml.etree.ElementTree as ET | ||||||
| 
 | 
 | ||||||
|  | from twisted.web.client import PartialDownloadError | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
| logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | def login_submission_legacy_convert(submission): | ||||||
|  |     """ | ||||||
|  |     If the input login submission is an old style object | ||||||
|  |     (ie. with top-level user / medium / address) convert it | ||||||
|  |     to a typed object. | ||||||
|  |     """ | ||||||
|  |     if "user" in submission: | ||||||
|  |         submission["identifier"] = { | ||||||
|  |             "type": "m.id.user", | ||||||
|  |             "user": submission["user"], | ||||||
|  |         } | ||||||
|  |         del submission["user"] | ||||||
|  | 
 | ||||||
|  |     if "medium" in submission and "address" in submission: | ||||||
|  |         submission["identifier"] = { | ||||||
|  |             "type": "m.id.thirdparty", | ||||||
|  |             "medium": submission["medium"], | ||||||
|  |             "address": submission["address"], | ||||||
|  |         } | ||||||
|  |         del submission["medium"] | ||||||
|  |         del submission["address"] | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def login_id_thirdparty_from_phone(identifier): | ||||||
|  |     """ | ||||||
|  |     Convert a phone login identifier type to a generic threepid identifier | ||||||
|  |     Args: | ||||||
|  |         identifier(dict): Login identifier dict of type 'm.id.phone' | ||||||
|  | 
 | ||||||
|  |     Returns: Login identifier dict of type 'm.id.threepid' | ||||||
|  |     """ | ||||||
|  |     if "country" not in identifier or "number" not in identifier: | ||||||
|  |         raise SynapseError(400, "Invalid phone-type identifier") | ||||||
|  | 
 | ||||||
|  |     msisdn = phone_number_to_msisdn(identifier["country"], identifier["number"]) | ||||||
|  | 
 | ||||||
|  |     return { | ||||||
|  |         "type": "m.id.thirdparty", | ||||||
|  |         "medium": "msisdn", | ||||||
|  |         "address": msisdn, | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| class LoginRestServlet(ClientV1RestServlet): | class LoginRestServlet(ClientV1RestServlet): | ||||||
|     PATTERNS = client_path_patterns("/login$") |     PATTERNS = client_path_patterns("/login$") | ||||||
|     PASS_TYPE = "m.login.password" |     PASS_TYPE = "m.login.password" | ||||||
| @ -117,20 +163,52 @@ class LoginRestServlet(ClientV1RestServlet): | |||||||
| 
 | 
 | ||||||
|     @defer.inlineCallbacks |     @defer.inlineCallbacks | ||||||
|     def do_password_login(self, login_submission): |     def do_password_login(self, login_submission): | ||||||
|         if 'medium' in login_submission and 'address' in login_submission: |         if "password" not in login_submission: | ||||||
|             address = login_submission['address'] |             raise SynapseError(400, "Missing parameter: password") | ||||||
|             if login_submission['medium'] == 'email': | 
 | ||||||
|  |         login_submission_legacy_convert(login_submission) | ||||||
|  | 
 | ||||||
|  |         if "identifier" not in login_submission: | ||||||
|  |             raise SynapseError(400, "Missing param: identifier") | ||||||
|  | 
 | ||||||
|  |         identifier = login_submission["identifier"] | ||||||
|  |         if "type" not in identifier: | ||||||
|  |             raise SynapseError(400, "Login identifier has no type") | ||||||
|  | 
 | ||||||
|  |         # convert phone type identifiers to generic threepids | ||||||
|  |         if identifier["type"] == "m.id.phone": | ||||||
|  |             identifier = login_id_thirdparty_from_phone(identifier) | ||||||
|  | 
 | ||||||
|  |         # convert threepid identifiers to user IDs | ||||||
|  |         if identifier["type"] == "m.id.thirdparty": | ||||||
|  |             if 'medium' not in identifier or 'address' not in identifier: | ||||||
|  |                 raise SynapseError(400, "Invalid thirdparty identifier") | ||||||
|  | 
 | ||||||
|  |             address = identifier['address'] | ||||||
|  |             if identifier['medium'] == 'email': | ||||||
|                 # For emails, transform the address to lowercase. |                 # For emails, transform the address to lowercase. | ||||||
|                 # We store all email addreses as lowercase in the DB. |                 # We store all email addreses as lowercase in the DB. | ||||||
|                 # (See add_threepid in synapse/handlers/auth.py) |                 # (See add_threepid in synapse/handlers/auth.py) | ||||||
|                 address = address.lower() |                 address = address.lower() | ||||||
|             user_id = yield self.hs.get_datastore().get_user_id_by_threepid( |             user_id = yield self.hs.get_datastore().get_user_id_by_threepid( | ||||||
|                 login_submission['medium'], address |                 identifier['medium'], address | ||||||
|             ) |             ) | ||||||
|             if not user_id: |             if not user_id: | ||||||
|                 raise LoginError(403, "", errcode=Codes.FORBIDDEN) |                 raise LoginError(403, "", errcode=Codes.FORBIDDEN) | ||||||
|         else: | 
 | ||||||
|             user_id = login_submission['user'] |             identifier = { | ||||||
|  |                 "type": "m.id.user", | ||||||
|  |                 "user": user_id, | ||||||
|  |             } | ||||||
|  | 
 | ||||||
|  |         # by this point, the identifier should be an m.id.user: if it's anything | ||||||
|  |         # else, we haven't understood it. | ||||||
|  |         if identifier["type"] != "m.id.user": | ||||||
|  |             raise SynapseError(400, "Unknown login identifier type") | ||||||
|  |         if "user" not in identifier: | ||||||
|  |             raise SynapseError(400, "User identifier is missing 'user' key") | ||||||
|  | 
 | ||||||
|  |         user_id = identifier["user"] | ||||||
| 
 | 
 | ||||||
|         if not user_id.startswith('@'): |         if not user_id.startswith('@'): | ||||||
|             user_id = UserID.create( |             user_id = UserID.create( | ||||||
| @ -341,7 +419,12 @@ class CasTicketServlet(ClientV1RestServlet): | |||||||
|             "ticket": request.args["ticket"], |             "ticket": request.args["ticket"], | ||||||
|             "service": self.cas_service_url |             "service": self.cas_service_url | ||||||
|         } |         } | ||||||
|  |         try: | ||||||
|             body = yield http_client.get_raw(uri, args) |             body = yield http_client.get_raw(uri, args) | ||||||
|  |         except PartialDownloadError as pde: | ||||||
|  |             # Twisted raises this error if the connection is closed, | ||||||
|  |             # even if that's being used old-http style to signal end-of-data | ||||||
|  |             body = pde.response | ||||||
|         result = yield self.handle_cas_response(request, body, client_redirect_url) |         result = yield self.handle_cas_response(request, body, client_redirect_url) | ||||||
|         defer.returnValue(result) |         defer.returnValue(result) | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -19,6 +19,7 @@ from twisted.internet import defer | |||||||
| 
 | 
 | ||||||
| from synapse.api.errors import SynapseError, AuthError | from synapse.api.errors import SynapseError, AuthError | ||||||
| from synapse.types import UserID | from synapse.types import UserID | ||||||
|  | from synapse.handlers.presence import format_user_presence_state | ||||||
| from synapse.http.servlet import parse_json_object_from_request | from synapse.http.servlet import parse_json_object_from_request | ||||||
| from .base import ClientV1RestServlet, client_path_patterns | from .base import ClientV1RestServlet, client_path_patterns | ||||||
| 
 | 
 | ||||||
| @ -33,6 +34,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet): | |||||||
|     def __init__(self, hs): |     def __init__(self, hs): | ||||||
|         super(PresenceStatusRestServlet, self).__init__(hs) |         super(PresenceStatusRestServlet, self).__init__(hs) | ||||||
|         self.presence_handler = hs.get_presence_handler() |         self.presence_handler = hs.get_presence_handler() | ||||||
|  |         self.clock = hs.get_clock() | ||||||
| 
 | 
 | ||||||
|     @defer.inlineCallbacks |     @defer.inlineCallbacks | ||||||
|     def on_GET(self, request, user_id): |     def on_GET(self, request, user_id): | ||||||
| @ -48,6 +50,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet): | |||||||
|                 raise AuthError(403, "You are not allowed to see their presence.") |                 raise AuthError(403, "You are not allowed to see their presence.") | ||||||
| 
 | 
 | ||||||
|         state = yield self.presence_handler.get_state(target_user=user) |         state = yield self.presence_handler.get_state(target_user=user) | ||||||
|  |         state = format_user_presence_state(state, self.clock.time_msec()) | ||||||
| 
 | 
 | ||||||
|         defer.returnValue((200, state)) |         defer.returnValue((200, state)) | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -748,8 +748,7 @@ class JoinedRoomsRestServlet(ClientV1RestServlet): | |||||||
|     def on_GET(self, request): |     def on_GET(self, request): | ||||||
|         requester = yield self.auth.get_user_by_req(request, allow_guest=True) |         requester = yield self.auth.get_user_by_req(request, allow_guest=True) | ||||||
| 
 | 
 | ||||||
|         rooms = yield self.store.get_rooms_for_user(requester.user.to_string()) |         room_ids = yield self.store.get_rooms_for_user(requester.user.to_string()) | ||||||
|         room_ids = set(r.room_id for r in rooms)  # Ensure they're unique. |  | ||||||
|         defer.returnValue((200, {"joined_rooms": list(room_ids)})) |         defer.returnValue((200, {"joined_rooms": list(room_ids)})) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -1,5 +1,6 @@ | |||||||
| # -*- coding: utf-8 -*- | # -*- coding: utf-8 -*- | ||||||
| # Copyright 2015, 2016 OpenMarket Ltd | # Copyright 2015, 2016 OpenMarket Ltd | ||||||
|  | # Copyright 2017 Vector Creations Ltd | ||||||
| # | # | ||||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||||
| @ -17,8 +18,11 @@ from twisted.internet import defer | |||||||
| 
 | 
 | ||||||
| from synapse.api.constants import LoginType | from synapse.api.constants import LoginType | ||||||
| from synapse.api.errors import LoginError, SynapseError, Codes | from synapse.api.errors import LoginError, SynapseError, Codes | ||||||
| from synapse.http.servlet import RestServlet, parse_json_object_from_request | from synapse.http.servlet import ( | ||||||
|  |     RestServlet, parse_json_object_from_request, assert_params_in_request | ||||||
|  | ) | ||||||
| from synapse.util.async import run_on_reactor | from synapse.util.async import run_on_reactor | ||||||
|  | from synapse.util.msisdn import phone_number_to_msisdn | ||||||
| 
 | 
 | ||||||
| from ._base import client_v2_patterns | from ._base import client_v2_patterns | ||||||
| 
 | 
 | ||||||
| @ -28,11 +32,11 @@ import logging | |||||||
| logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class PasswordRequestTokenRestServlet(RestServlet): | class EmailPasswordRequestTokenRestServlet(RestServlet): | ||||||
|     PATTERNS = client_v2_patterns("/account/password/email/requestToken$") |     PATTERNS = client_v2_patterns("/account/password/email/requestToken$") | ||||||
| 
 | 
 | ||||||
|     def __init__(self, hs): |     def __init__(self, hs): | ||||||
|         super(PasswordRequestTokenRestServlet, self).__init__() |         super(EmailPasswordRequestTokenRestServlet, self).__init__() | ||||||
|         self.hs = hs |         self.hs = hs | ||||||
|         self.identity_handler = hs.get_handlers().identity_handler |         self.identity_handler = hs.get_handlers().identity_handler | ||||||
| 
 | 
 | ||||||
| @ -40,14 +44,9 @@ class PasswordRequestTokenRestServlet(RestServlet): | |||||||
|     def on_POST(self, request): |     def on_POST(self, request): | ||||||
|         body = parse_json_object_from_request(request) |         body = parse_json_object_from_request(request) | ||||||
| 
 | 
 | ||||||
|         required = ['id_server', 'client_secret', 'email', 'send_attempt'] |         assert_params_in_request(body, [ | ||||||
|         absent = [] |             'id_server', 'client_secret', 'email', 'send_attempt' | ||||||
|         for k in required: |         ]) | ||||||
|             if k not in body: |  | ||||||
|                 absent.append(k) |  | ||||||
| 
 |  | ||||||
|         if absent: |  | ||||||
|             raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM) |  | ||||||
| 
 | 
 | ||||||
|         existingUid = yield self.hs.get_datastore().get_user_id_by_threepid( |         existingUid = yield self.hs.get_datastore().get_user_id_by_threepid( | ||||||
|             'email', body['email'] |             'email', body['email'] | ||||||
| @ -60,6 +59,37 @@ class PasswordRequestTokenRestServlet(RestServlet): | |||||||
|         defer.returnValue((200, ret)) |         defer.returnValue((200, ret)) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | class MsisdnPasswordRequestTokenRestServlet(RestServlet): | ||||||
|  |     PATTERNS = client_v2_patterns("/account/password/msisdn/requestToken$") | ||||||
|  | 
 | ||||||
|  |     def __init__(self, hs): | ||||||
|  |         super(MsisdnPasswordRequestTokenRestServlet, self).__init__() | ||||||
|  |         self.hs = hs | ||||||
|  |         self.datastore = self.hs.get_datastore() | ||||||
|  |         self.identity_handler = hs.get_handlers().identity_handler | ||||||
|  | 
 | ||||||
|  |     @defer.inlineCallbacks | ||||||
|  |     def on_POST(self, request): | ||||||
|  |         body = parse_json_object_from_request(request) | ||||||
|  | 
 | ||||||
|  |         assert_params_in_request(body, [ | ||||||
|  |             'id_server', 'client_secret', | ||||||
|  |             'country', 'phone_number', 'send_attempt', | ||||||
|  |         ]) | ||||||
|  | 
 | ||||||
|  |         msisdn = phone_number_to_msisdn(body['country'], body['phone_number']) | ||||||
|  | 
 | ||||||
|  |         existingUid = yield self.datastore.get_user_id_by_threepid( | ||||||
|  |             'msisdn', msisdn | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         if existingUid is None: | ||||||
|  |             raise SynapseError(400, "MSISDN not found", Codes.THREEPID_NOT_FOUND) | ||||||
|  | 
 | ||||||
|  |         ret = yield self.identity_handler.requestMsisdnToken(**body) | ||||||
|  |         defer.returnValue((200, ret)) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| class PasswordRestServlet(RestServlet): | class PasswordRestServlet(RestServlet): | ||||||
|     PATTERNS = client_v2_patterns("/account/password$") |     PATTERNS = client_v2_patterns("/account/password$") | ||||||
| 
 | 
 | ||||||
| @ -68,6 +98,7 @@ class PasswordRestServlet(RestServlet): | |||||||
|         self.hs = hs |         self.hs = hs | ||||||
|         self.auth = hs.get_auth() |         self.auth = hs.get_auth() | ||||||
|         self.auth_handler = hs.get_auth_handler() |         self.auth_handler = hs.get_auth_handler() | ||||||
|  |         self.datastore = self.hs.get_datastore() | ||||||
| 
 | 
 | ||||||
|     @defer.inlineCallbacks |     @defer.inlineCallbacks | ||||||
|     def on_POST(self, request): |     def on_POST(self, request): | ||||||
| @ -77,7 +108,8 @@ class PasswordRestServlet(RestServlet): | |||||||
| 
 | 
 | ||||||
|         authed, result, params, _ = yield self.auth_handler.check_auth([ |         authed, result, params, _ = yield self.auth_handler.check_auth([ | ||||||
|             [LoginType.PASSWORD], |             [LoginType.PASSWORD], | ||||||
|             [LoginType.EMAIL_IDENTITY] |             [LoginType.EMAIL_IDENTITY], | ||||||
|  |             [LoginType.MSISDN], | ||||||
|         ], body, self.hs.get_ip_from_request(request)) |         ], body, self.hs.get_ip_from_request(request)) | ||||||
| 
 | 
 | ||||||
|         if not authed: |         if not authed: | ||||||
| @ -102,7 +134,7 @@ class PasswordRestServlet(RestServlet): | |||||||
|                 # (See add_threepid in synapse/handlers/auth.py) |                 # (See add_threepid in synapse/handlers/auth.py) | ||||||
|                 threepid['address'] = threepid['address'].lower() |                 threepid['address'] = threepid['address'].lower() | ||||||
|             # if using email, we must know about the email they're authing with! |             # if using email, we must know about the email they're authing with! | ||||||
|             threepid_user_id = yield self.hs.get_datastore().get_user_id_by_threepid( |             threepid_user_id = yield self.datastore.get_user_id_by_threepid( | ||||||
|                 threepid['medium'], threepid['address'] |                 threepid['medium'], threepid['address'] | ||||||
|             ) |             ) | ||||||
|             if not threepid_user_id: |             if not threepid_user_id: | ||||||
| @ -169,13 +201,14 @@ class DeactivateAccountRestServlet(RestServlet): | |||||||
|         defer.returnValue((200, {})) |         defer.returnValue((200, {})) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class ThreepidRequestTokenRestServlet(RestServlet): | class EmailThreepidRequestTokenRestServlet(RestServlet): | ||||||
|     PATTERNS = client_v2_patterns("/account/3pid/email/requestToken$") |     PATTERNS = client_v2_patterns("/account/3pid/email/requestToken$") | ||||||
| 
 | 
 | ||||||
|     def __init__(self, hs): |     def __init__(self, hs): | ||||||
|         self.hs = hs |         self.hs = hs | ||||||
|         super(ThreepidRequestTokenRestServlet, self).__init__() |         super(EmailThreepidRequestTokenRestServlet, self).__init__() | ||||||
|         self.identity_handler = hs.get_handlers().identity_handler |         self.identity_handler = hs.get_handlers().identity_handler | ||||||
|  |         self.datastore = self.hs.get_datastore() | ||||||
| 
 | 
 | ||||||
|     @defer.inlineCallbacks |     @defer.inlineCallbacks | ||||||
|     def on_POST(self, request): |     def on_POST(self, request): | ||||||
| @ -190,7 +223,7 @@ class ThreepidRequestTokenRestServlet(RestServlet): | |||||||
|         if absent: |         if absent: | ||||||
|             raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM) |             raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM) | ||||||
| 
 | 
 | ||||||
|         existingUid = yield self.hs.get_datastore().get_user_id_by_threepid( |         existingUid = yield self.datastore.get_user_id_by_threepid( | ||||||
|             'email', body['email'] |             'email', body['email'] | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
| @ -201,6 +234,44 @@ class ThreepidRequestTokenRestServlet(RestServlet): | |||||||
|         defer.returnValue((200, ret)) |         defer.returnValue((200, ret)) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | class MsisdnThreepidRequestTokenRestServlet(RestServlet): | ||||||
|  |     PATTERNS = client_v2_patterns("/account/3pid/msisdn/requestToken$") | ||||||
|  | 
 | ||||||
|  |     def __init__(self, hs): | ||||||
|  |         self.hs = hs | ||||||
|  |         super(MsisdnThreepidRequestTokenRestServlet, self).__init__() | ||||||
|  |         self.identity_handler = hs.get_handlers().identity_handler | ||||||
|  |         self.datastore = self.hs.get_datastore() | ||||||
|  | 
 | ||||||
|  |     @defer.inlineCallbacks | ||||||
|  |     def on_POST(self, request): | ||||||
|  |         body = parse_json_object_from_request(request) | ||||||
|  | 
 | ||||||
|  |         required = [ | ||||||
|  |             'id_server', 'client_secret', | ||||||
|  |             'country', 'phone_number', 'send_attempt', | ||||||
|  |         ] | ||||||
|  |         absent = [] | ||||||
|  |         for k in required: | ||||||
|  |             if k not in body: | ||||||
|  |                 absent.append(k) | ||||||
|  | 
 | ||||||
|  |         if absent: | ||||||
|  |             raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM) | ||||||
|  | 
 | ||||||
|  |         msisdn = phone_number_to_msisdn(body['country'], body['phone_number']) | ||||||
|  | 
 | ||||||
|  |         existingUid = yield self.datastore.get_user_id_by_threepid( | ||||||
|  |             'msisdn', msisdn | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         if existingUid is not None: | ||||||
|  |             raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE) | ||||||
|  | 
 | ||||||
|  |         ret = yield self.identity_handler.requestMsisdnToken(**body) | ||||||
|  |         defer.returnValue((200, ret)) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| class ThreepidRestServlet(RestServlet): | class ThreepidRestServlet(RestServlet): | ||||||
|     PATTERNS = client_v2_patterns("/account/3pid$") |     PATTERNS = client_v2_patterns("/account/3pid$") | ||||||
| 
 | 
 | ||||||
| @ -210,6 +281,7 @@ class ThreepidRestServlet(RestServlet): | |||||||
|         self.identity_handler = hs.get_handlers().identity_handler |         self.identity_handler = hs.get_handlers().identity_handler | ||||||
|         self.auth = hs.get_auth() |         self.auth = hs.get_auth() | ||||||
|         self.auth_handler = hs.get_auth_handler() |         self.auth_handler = hs.get_auth_handler() | ||||||
|  |         self.datastore = self.hs.get_datastore() | ||||||
| 
 | 
 | ||||||
|     @defer.inlineCallbacks |     @defer.inlineCallbacks | ||||||
|     def on_GET(self, request): |     def on_GET(self, request): | ||||||
| @ -217,7 +289,7 @@ class ThreepidRestServlet(RestServlet): | |||||||
| 
 | 
 | ||||||
|         requester = yield self.auth.get_user_by_req(request) |         requester = yield self.auth.get_user_by_req(request) | ||||||
| 
 | 
 | ||||||
|         threepids = yield self.hs.get_datastore().user_get_threepids( |         threepids = yield self.datastore.user_get_threepids( | ||||||
|             requester.user.to_string() |             requester.user.to_string() | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
| @ -258,7 +330,7 @@ class ThreepidRestServlet(RestServlet): | |||||||
| 
 | 
 | ||||||
|         if 'bind' in body and body['bind']: |         if 'bind' in body and body['bind']: | ||||||
|             logger.debug( |             logger.debug( | ||||||
|                 "Binding emails %s to %s", |                 "Binding threepid %s to %s", | ||||||
|                 threepid, user_id |                 threepid, user_id | ||||||
|             ) |             ) | ||||||
|             yield self.identity_handler.bind_threepid( |             yield self.identity_handler.bind_threepid( | ||||||
| @ -302,9 +374,11 @@ class ThreepidDeleteRestServlet(RestServlet): | |||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def register_servlets(hs, http_server): | def register_servlets(hs, http_server): | ||||||
|     PasswordRequestTokenRestServlet(hs).register(http_server) |     EmailPasswordRequestTokenRestServlet(hs).register(http_server) | ||||||
|  |     MsisdnPasswordRequestTokenRestServlet(hs).register(http_server) | ||||||
|     PasswordRestServlet(hs).register(http_server) |     PasswordRestServlet(hs).register(http_server) | ||||||
|     DeactivateAccountRestServlet(hs).register(http_server) |     DeactivateAccountRestServlet(hs).register(http_server) | ||||||
|     ThreepidRequestTokenRestServlet(hs).register(http_server) |     EmailThreepidRequestTokenRestServlet(hs).register(http_server) | ||||||
|  |     MsisdnThreepidRequestTokenRestServlet(hs).register(http_server) | ||||||
|     ThreepidRestServlet(hs).register(http_server) |     ThreepidRestServlet(hs).register(http_server) | ||||||
|     ThreepidDeleteRestServlet(hs).register(http_server) |     ThreepidDeleteRestServlet(hs).register(http_server) | ||||||
|  | |||||||
| @ -46,6 +46,52 @@ class DevicesRestServlet(servlet.RestServlet): | |||||||
|         defer.returnValue((200, {"devices": devices})) |         defer.returnValue((200, {"devices": devices})) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | class DeleteDevicesRestServlet(servlet.RestServlet): | ||||||
|  |     """ | ||||||
|  |     API for bulk deletion of devices. Accepts a JSON object with a devices | ||||||
|  |     key which lists the device_ids to delete. Requires user interactive auth. | ||||||
|  |     """ | ||||||
|  |     PATTERNS = client_v2_patterns("/delete_devices", releases=[], v2_alpha=False) | ||||||
|  | 
 | ||||||
|  |     def __init__(self, hs): | ||||||
|  |         super(DeleteDevicesRestServlet, self).__init__() | ||||||
|  |         self.hs = hs | ||||||
|  |         self.auth = hs.get_auth() | ||||||
|  |         self.device_handler = hs.get_device_handler() | ||||||
|  |         self.auth_handler = hs.get_auth_handler() | ||||||
|  | 
 | ||||||
|  |     @defer.inlineCallbacks | ||||||
|  |     def on_POST(self, request): | ||||||
|  |         try: | ||||||
|  |             body = servlet.parse_json_object_from_request(request) | ||||||
|  |         except errors.SynapseError as e: | ||||||
|  |             if e.errcode == errors.Codes.NOT_JSON: | ||||||
|  |                 # deal with older clients which didn't pass a J*DELETESON dict | ||||||
|  |                 # the same as those that pass an empty dict | ||||||
|  |                 body = {} | ||||||
|  |             else: | ||||||
|  |                 raise e | ||||||
|  | 
 | ||||||
|  |         if 'devices' not in body: | ||||||
|  |             raise errors.SynapseError( | ||||||
|  |                 400, "No devices supplied", errcode=errors.Codes.MISSING_PARAM | ||||||
|  |             ) | ||||||
|  | 
 | ||||||
|  |         authed, result, params, _ = yield self.auth_handler.check_auth([ | ||||||
|  |             [constants.LoginType.PASSWORD], | ||||||
|  |         ], body, self.hs.get_ip_from_request(request)) | ||||||
|  | 
 | ||||||
|  |         if not authed: | ||||||
|  |             defer.returnValue((401, result)) | ||||||
|  | 
 | ||||||
|  |         requester = yield self.auth.get_user_by_req(request) | ||||||
|  |         yield self.device_handler.delete_devices( | ||||||
|  |             requester.user.to_string(), | ||||||
|  |             body['devices'], | ||||||
|  |         ) | ||||||
|  |         defer.returnValue((200, {})) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| class DeviceRestServlet(servlet.RestServlet): | class DeviceRestServlet(servlet.RestServlet): | ||||||
|     PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$", |     PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$", | ||||||
|                                   releases=[], v2_alpha=False) |                                   releases=[], v2_alpha=False) | ||||||
| @ -111,5 +157,6 @@ class DeviceRestServlet(servlet.RestServlet): | |||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def register_servlets(hs, http_server): | def register_servlets(hs, http_server): | ||||||
|  |     DeleteDevicesRestServlet(hs).register(http_server) | ||||||
|     DevicesRestServlet(hs).register(http_server) |     DevicesRestServlet(hs).register(http_server) | ||||||
|     DeviceRestServlet(hs).register(http_server) |     DeviceRestServlet(hs).register(http_server) | ||||||
|  | |||||||
| @ -1,5 +1,6 @@ | |||||||
| # -*- coding: utf-8 -*- | # -*- coding: utf-8 -*- | ||||||
| # Copyright 2015 - 2016 OpenMarket Ltd | # Copyright 2015 - 2016 OpenMarket Ltd | ||||||
|  | # Copyright 2017 Vector Creations Ltd | ||||||
| # | # | ||||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||||
| @ -19,7 +20,10 @@ import synapse | |||||||
| from synapse.api.auth import get_access_token_from_request, has_access_token | from synapse.api.auth import get_access_token_from_request, has_access_token | ||||||
| from synapse.api.constants import LoginType | from synapse.api.constants import LoginType | ||||||
| from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError | from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError | ||||||
| from synapse.http.servlet import RestServlet, parse_json_object_from_request | from synapse.http.servlet import ( | ||||||
|  |     RestServlet, parse_json_object_from_request, assert_params_in_request | ||||||
|  | ) | ||||||
|  | from synapse.util.msisdn import phone_number_to_msisdn | ||||||
| 
 | 
 | ||||||
| from ._base import client_v2_patterns | from ._base import client_v2_patterns | ||||||
| 
 | 
 | ||||||
| @ -43,7 +47,7 @@ else: | |||||||
| logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class RegisterRequestTokenRestServlet(RestServlet): | class EmailRegisterRequestTokenRestServlet(RestServlet): | ||||||
|     PATTERNS = client_v2_patterns("/register/email/requestToken$") |     PATTERNS = client_v2_patterns("/register/email/requestToken$") | ||||||
| 
 | 
 | ||||||
|     def __init__(self, hs): |     def __init__(self, hs): | ||||||
| @ -51,7 +55,7 @@ class RegisterRequestTokenRestServlet(RestServlet): | |||||||
|         Args: |         Args: | ||||||
|             hs (synapse.server.HomeServer): server |             hs (synapse.server.HomeServer): server | ||||||
|         """ |         """ | ||||||
|         super(RegisterRequestTokenRestServlet, self).__init__() |         super(EmailRegisterRequestTokenRestServlet, self).__init__() | ||||||
|         self.hs = hs |         self.hs = hs | ||||||
|         self.identity_handler = hs.get_handlers().identity_handler |         self.identity_handler = hs.get_handlers().identity_handler | ||||||
| 
 | 
 | ||||||
| @ -59,14 +63,9 @@ class RegisterRequestTokenRestServlet(RestServlet): | |||||||
|     def on_POST(self, request): |     def on_POST(self, request): | ||||||
|         body = parse_json_object_from_request(request) |         body = parse_json_object_from_request(request) | ||||||
| 
 | 
 | ||||||
|         required = ['id_server', 'client_secret', 'email', 'send_attempt'] |         assert_params_in_request(body, [ | ||||||
|         absent = [] |             'id_server', 'client_secret', 'email', 'send_attempt' | ||||||
|         for k in required: |         ]) | ||||||
|             if k not in body: |  | ||||||
|                 absent.append(k) |  | ||||||
| 
 |  | ||||||
|         if len(absent) > 0: |  | ||||||
|             raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM) |  | ||||||
| 
 | 
 | ||||||
|         existingUid = yield self.hs.get_datastore().get_user_id_by_threepid( |         existingUid = yield self.hs.get_datastore().get_user_id_by_threepid( | ||||||
|             'email', body['email'] |             'email', body['email'] | ||||||
| @ -79,6 +78,43 @@ class RegisterRequestTokenRestServlet(RestServlet): | |||||||
|         defer.returnValue((200, ret)) |         defer.returnValue((200, ret)) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | class MsisdnRegisterRequestTokenRestServlet(RestServlet): | ||||||
|  |     PATTERNS = client_v2_patterns("/register/msisdn/requestToken$") | ||||||
|  | 
 | ||||||
|  |     def __init__(self, hs): | ||||||
|  |         """ | ||||||
|  |         Args: | ||||||
|  |             hs (synapse.server.HomeServer): server | ||||||
|  |         """ | ||||||
|  |         super(MsisdnRegisterRequestTokenRestServlet, self).__init__() | ||||||
|  |         self.hs = hs | ||||||
|  |         self.identity_handler = hs.get_handlers().identity_handler | ||||||
|  | 
 | ||||||
|  |     @defer.inlineCallbacks | ||||||
|  |     def on_POST(self, request): | ||||||
|  |         body = parse_json_object_from_request(request) | ||||||
|  | 
 | ||||||
|  |         assert_params_in_request(body, [ | ||||||
|  |             'id_server', 'client_secret', | ||||||
|  |             'country', 'phone_number', | ||||||
|  |             'send_attempt', | ||||||
|  |         ]) | ||||||
|  | 
 | ||||||
|  |         msisdn = phone_number_to_msisdn(body['country'], body['phone_number']) | ||||||
|  | 
 | ||||||
|  |         existingUid = yield self.hs.get_datastore().get_user_id_by_threepid( | ||||||
|  |             'msisdn', msisdn | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         if existingUid is not None: | ||||||
|  |             raise SynapseError( | ||||||
|  |                 400, "Phone number is already in use", Codes.THREEPID_IN_USE | ||||||
|  |             ) | ||||||
|  | 
 | ||||||
|  |         ret = yield self.identity_handler.requestMsisdnToken(**body) | ||||||
|  |         defer.returnValue((200, ret)) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| class RegisterRestServlet(RestServlet): | class RegisterRestServlet(RestServlet): | ||||||
|     PATTERNS = client_v2_patterns("/register$") |     PATTERNS = client_v2_patterns("/register$") | ||||||
| 
 | 
 | ||||||
| @ -200,16 +236,37 @@ class RegisterRestServlet(RestServlet): | |||||||
|                 assigned_user_id=registered_user_id, |                 assigned_user_id=registered_user_id, | ||||||
|             ) |             ) | ||||||
| 
 | 
 | ||||||
|  |         # Only give msisdn flows if the x_show_msisdn flag is given: | ||||||
|  |         # this is a hack to work around the fact that clients were shipped | ||||||
|  |         # that use fallback registration if they see any flows that they don't | ||||||
|  |         # recognise, which means we break registration for these clients if we | ||||||
|  |         # advertise msisdn flows. Once usage of Riot iOS <=0.3.9 and Riot | ||||||
|  |         # Android <=0.6.9 have fallen below an acceptable threshold, this | ||||||
|  |         # parameter should go away and we should always advertise msisdn flows. | ||||||
|  |         show_msisdn = False | ||||||
|  |         if 'x_show_msisdn' in body and body['x_show_msisdn']: | ||||||
|  |             show_msisdn = True | ||||||
|  | 
 | ||||||
|         if self.hs.config.enable_registration_captcha: |         if self.hs.config.enable_registration_captcha: | ||||||
|             flows = [ |             flows = [ | ||||||
|                 [LoginType.RECAPTCHA], |                 [LoginType.RECAPTCHA], | ||||||
|                 [LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA] |                 [LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA], | ||||||
|             ] |             ] | ||||||
|  |             if show_msisdn: | ||||||
|  |                 flows.extend([ | ||||||
|  |                     [LoginType.MSISDN, LoginType.RECAPTCHA], | ||||||
|  |                     [LoginType.MSISDN, LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA], | ||||||
|  |                 ]) | ||||||
|         else: |         else: | ||||||
|             flows = [ |             flows = [ | ||||||
|                 [LoginType.DUMMY], |                 [LoginType.DUMMY], | ||||||
|                 [LoginType.EMAIL_IDENTITY] |                 [LoginType.EMAIL_IDENTITY], | ||||||
|             ] |             ] | ||||||
|  |             if show_msisdn: | ||||||
|  |                 flows.extend([ | ||||||
|  |                     [LoginType.MSISDN], | ||||||
|  |                     [LoginType.MSISDN, LoginType.EMAIL_IDENTITY], | ||||||
|  |                 ]) | ||||||
| 
 | 
 | ||||||
|         authed, auth_result, params, session_id = yield self.auth_handler.check_auth( |         authed, auth_result, params, session_id = yield self.auth_handler.check_auth( | ||||||
|             flows, body, self.hs.get_ip_from_request(request) |             flows, body, self.hs.get_ip_from_request(request) | ||||||
| @ -224,8 +281,9 @@ class RegisterRestServlet(RestServlet): | |||||||
|                 "Already registered user ID %r for this session", |                 "Already registered user ID %r for this session", | ||||||
|                 registered_user_id |                 registered_user_id | ||||||
|             ) |             ) | ||||||
|             # don't re-register the email address |             # don't re-register the threepids | ||||||
|             add_email = False |             add_email = False | ||||||
|  |             add_msisdn = False | ||||||
|         else: |         else: | ||||||
|             # NB: This may be from the auth handler and NOT from the POST |             # NB: This may be from the auth handler and NOT from the POST | ||||||
|             if 'password' not in params: |             if 'password' not in params: | ||||||
| @ -250,6 +308,7 @@ class RegisterRestServlet(RestServlet): | |||||||
|             ) |             ) | ||||||
| 
 | 
 | ||||||
|             add_email = True |             add_email = True | ||||||
|  |             add_msisdn = True | ||||||
| 
 | 
 | ||||||
|         return_dict = yield self._create_registration_details( |         return_dict = yield self._create_registration_details( | ||||||
|             registered_user_id, params |             registered_user_id, params | ||||||
| @ -262,6 +321,13 @@ class RegisterRestServlet(RestServlet): | |||||||
|                 params.get("bind_email") |                 params.get("bind_email") | ||||||
|             ) |             ) | ||||||
| 
 | 
 | ||||||
|  |         if add_msisdn and auth_result and LoginType.MSISDN in auth_result: | ||||||
|  |             threepid = auth_result[LoginType.MSISDN] | ||||||
|  |             yield self._register_msisdn_threepid( | ||||||
|  |                 registered_user_id, threepid, return_dict["access_token"], | ||||||
|  |                 params.get("bind_msisdn") | ||||||
|  |             ) | ||||||
|  | 
 | ||||||
|         defer.returnValue((200, return_dict)) |         defer.returnValue((200, return_dict)) | ||||||
| 
 | 
 | ||||||
|     def on_OPTIONS(self, _): |     def on_OPTIONS(self, _): | ||||||
| @ -323,8 +389,9 @@ class RegisterRestServlet(RestServlet): | |||||||
|         """ |         """ | ||||||
|         reqd = ('medium', 'address', 'validated_at') |         reqd = ('medium', 'address', 'validated_at') | ||||||
|         if any(x not in threepid for x in reqd): |         if any(x not in threepid for x in reqd): | ||||||
|  |             # This will only happen if the ID server returns a malformed response | ||||||
|             logger.info("Can't add incomplete 3pid") |             logger.info("Can't add incomplete 3pid") | ||||||
|             defer.returnValue() |             return | ||||||
| 
 | 
 | ||||||
|         yield self.auth_handler.add_threepid( |         yield self.auth_handler.add_threepid( | ||||||
|             user_id, |             user_id, | ||||||
| @ -371,6 +438,43 @@ class RegisterRestServlet(RestServlet): | |||||||
|         else: |         else: | ||||||
|             logger.info("bind_email not specified: not binding email") |             logger.info("bind_email not specified: not binding email") | ||||||
| 
 | 
 | ||||||
|  |     @defer.inlineCallbacks | ||||||
|  |     def _register_msisdn_threepid(self, user_id, threepid, token, bind_msisdn): | ||||||
|  |         """Add a phone number as a 3pid identifier | ||||||
|  | 
 | ||||||
|  |         Also optionally binds msisdn to the given user_id on the identity server | ||||||
|  | 
 | ||||||
|  |         Args: | ||||||
|  |             user_id (str): id of user | ||||||
|  |             threepid (object): m.login.msisdn auth response | ||||||
|  |             token (str): access_token for the user | ||||||
|  |             bind_email (bool): true if the client requested the email to be | ||||||
|  |                 bound at the identity server | ||||||
|  |         Returns: | ||||||
|  |             defer.Deferred: | ||||||
|  |         """ | ||||||
|  |         reqd = ('medium', 'address', 'validated_at') | ||||||
|  |         if any(x not in threepid for x in reqd): | ||||||
|  |             # This will only happen if the ID server returns a malformed response | ||||||
|  |             logger.info("Can't add incomplete 3pid") | ||||||
|  |             defer.returnValue() | ||||||
|  | 
 | ||||||
|  |         yield self.auth_handler.add_threepid( | ||||||
|  |             user_id, | ||||||
|  |             threepid['medium'], | ||||||
|  |             threepid['address'], | ||||||
|  |             threepid['validated_at'], | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         if bind_msisdn: | ||||||
|  |             logger.info("bind_msisdn specified: binding") | ||||||
|  |             logger.debug("Binding msisdn %s to %s", threepid, user_id) | ||||||
|  |             yield self.identity_handler.bind_threepid( | ||||||
|  |                 threepid['threepid_creds'], user_id | ||||||
|  |             ) | ||||||
|  |         else: | ||||||
|  |             logger.info("bind_msisdn not specified: not binding msisdn") | ||||||
|  | 
 | ||||||
|     @defer.inlineCallbacks |     @defer.inlineCallbacks | ||||||
|     def _create_registration_details(self, user_id, params): |     def _create_registration_details(self, user_id, params): | ||||||
|         """Complete registration of newly-registered user |         """Complete registration of newly-registered user | ||||||
| @ -433,7 +537,7 @@ class RegisterRestServlet(RestServlet): | |||||||
|         # we have nowhere to store it. |         # we have nowhere to store it. | ||||||
|         device_id = synapse.api.auth.GUEST_DEVICE_ID |         device_id = synapse.api.auth.GUEST_DEVICE_ID | ||||||
|         initial_display_name = params.get("initial_device_display_name") |         initial_display_name = params.get("initial_device_display_name") | ||||||
|         self.device_handler.check_device_registered( |         yield self.device_handler.check_device_registered( | ||||||
|             user_id, device_id, initial_display_name |             user_id, device_id, initial_display_name | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
| @ -449,5 +553,6 @@ class RegisterRestServlet(RestServlet): | |||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def register_servlets(hs, http_server): | def register_servlets(hs, http_server): | ||||||
|     RegisterRequestTokenRestServlet(hs).register(http_server) |     EmailRegisterRequestTokenRestServlet(hs).register(http_server) | ||||||
|  |     MsisdnRegisterRequestTokenRestServlet(hs).register(http_server) | ||||||
|     RegisterRestServlet(hs).register(http_server) |     RegisterRestServlet(hs).register(http_server) | ||||||
|  | |||||||
| @ -18,6 +18,7 @@ from twisted.internet import defer | |||||||
| from synapse.http.servlet import ( | from synapse.http.servlet import ( | ||||||
|     RestServlet, parse_string, parse_integer, parse_boolean |     RestServlet, parse_string, parse_integer, parse_boolean | ||||||
| ) | ) | ||||||
|  | from synapse.handlers.presence import format_user_presence_state | ||||||
| from synapse.handlers.sync import SyncConfig | from synapse.handlers.sync import SyncConfig | ||||||
| from synapse.types import StreamToken | from synapse.types import StreamToken | ||||||
| from synapse.events.utils import ( | from synapse.events.utils import ( | ||||||
| @ -28,7 +29,6 @@ from synapse.api.errors import SynapseError | |||||||
| from synapse.api.constants import PresenceState | from synapse.api.constants import PresenceState | ||||||
| from ._base import client_v2_patterns | from ._base import client_v2_patterns | ||||||
| 
 | 
 | ||||||
| import copy |  | ||||||
| import itertools | import itertools | ||||||
| import logging | import logging | ||||||
| 
 | 
 | ||||||
| @ -194,12 +194,18 @@ class SyncRestServlet(RestServlet): | |||||||
|         defer.returnValue((200, response_content)) |         defer.returnValue((200, response_content)) | ||||||
| 
 | 
 | ||||||
|     def encode_presence(self, events, time_now): |     def encode_presence(self, events, time_now): | ||||||
|         formatted = [] |         return { | ||||||
|         for event in events: |             "events": [ | ||||||
|             event = copy.deepcopy(event) |                 { | ||||||
|             event['sender'] = event['content'].pop('user_id') |                     "type": "m.presence", | ||||||
|             formatted.append(event) |                     "sender": event.user_id, | ||||||
|         return {"events": formatted} |                     "content": format_user_presence_state( | ||||||
|  |                         event, time_now, include_user_id=False | ||||||
|  |                     ), | ||||||
|  |                 } | ||||||
|  |                 for event in events | ||||||
|  |             ] | ||||||
|  |         } | ||||||
| 
 | 
 | ||||||
|     def encode_joined(self, rooms, time_now, token_id, event_fields): |     def encode_joined(self, rooms, time_now, token_id, event_fields): | ||||||
|         """ |         """ | ||||||
|  | |||||||
| @ -12,6 +12,7 @@ | |||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
| # See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||||
| # limitations under the License. | # limitations under the License. | ||||||
|  | import synapse.http.servlet | ||||||
| 
 | 
 | ||||||
| from ._base import parse_media_id, respond_with_file, respond_404 | from ._base import parse_media_id, respond_with_file, respond_404 | ||||||
| from twisted.web.resource import Resource | from twisted.web.resource import Resource | ||||||
| @ -81,6 +82,17 @@ class DownloadResource(Resource): | |||||||
| 
 | 
 | ||||||
|     @defer.inlineCallbacks |     @defer.inlineCallbacks | ||||||
|     def _respond_remote_file(self, request, server_name, media_id, name): |     def _respond_remote_file(self, request, server_name, media_id, name): | ||||||
|  |         # don't forward requests for remote media if allow_remote is false | ||||||
|  |         allow_remote = synapse.http.servlet.parse_boolean( | ||||||
|  |             request, "allow_remote", default=True) | ||||||
|  |         if not allow_remote: | ||||||
|  |             logger.info( | ||||||
|  |                 "Rejecting request for remote media %s/%s due to allow_remote", | ||||||
|  |                 server_name, media_id, | ||||||
|  |             ) | ||||||
|  |             respond_404(request) | ||||||
|  |             return | ||||||
|  | 
 | ||||||
|         media_info = yield self.media_repo.get_remote_media(server_name, media_id) |         media_info = yield self.media_repo.get_remote_media(server_name, media_id) | ||||||
| 
 | 
 | ||||||
|         media_type = media_info["media_type"] |         media_type = media_info["media_type"] | ||||||
|  | |||||||
| @ -13,22 +13,23 @@ | |||||||
| # See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||||
| # limitations under the License. | # limitations under the License. | ||||||
| 
 | 
 | ||||||
|  | from twisted.internet import defer, threads | ||||||
|  | import twisted.internet.error | ||||||
|  | import twisted.web.http | ||||||
|  | from twisted.web.resource import Resource | ||||||
|  | 
 | ||||||
| from .upload_resource import UploadResource | from .upload_resource import UploadResource | ||||||
| from .download_resource import DownloadResource | from .download_resource import DownloadResource | ||||||
| from .thumbnail_resource import ThumbnailResource | from .thumbnail_resource import ThumbnailResource | ||||||
| from .identicon_resource import IdenticonResource | from .identicon_resource import IdenticonResource | ||||||
| from .preview_url_resource import PreviewUrlResource | from .preview_url_resource import PreviewUrlResource | ||||||
| from .filepath import MediaFilePaths | from .filepath import MediaFilePaths | ||||||
| 
 |  | ||||||
| from twisted.web.resource import Resource |  | ||||||
| 
 |  | ||||||
| from .thumbnailer import Thumbnailer | from .thumbnailer import Thumbnailer | ||||||
| 
 | 
 | ||||||
| from synapse.http.matrixfederationclient import MatrixFederationHttpClient | from synapse.http.matrixfederationclient import MatrixFederationHttpClient | ||||||
| from synapse.util.stringutils import random_string | from synapse.util.stringutils import random_string | ||||||
| from synapse.api.errors import SynapseError | from synapse.api.errors import SynapseError, HttpResponseException, \ | ||||||
| 
 |     NotFoundError | ||||||
| from twisted.internet import defer, threads |  | ||||||
| 
 | 
 | ||||||
| from synapse.util.async import Linearizer | from synapse.util.async import Linearizer | ||||||
| from synapse.util.stringutils import is_ascii | from synapse.util.stringutils import is_ascii | ||||||
| @ -157,11 +158,34 @@ class MediaRepository(object): | |||||||
|                 try: |                 try: | ||||||
|                     length, headers = yield self.client.get_file( |                     length, headers = yield self.client.get_file( | ||||||
|                         server_name, request_path, output_stream=f, |                         server_name, request_path, output_stream=f, | ||||||
|                         max_size=self.max_upload_size, |                         max_size=self.max_upload_size, args={ | ||||||
|  |                             # tell the remote server to 404 if it doesn't | ||||||
|  |                             # recognise the server_name, to make sure we don't | ||||||
|  |                             # end up with a routing loop. | ||||||
|  |                             "allow_remote": "false", | ||||||
|  |                         } | ||||||
|                     ) |                     ) | ||||||
|                 except Exception as e: |                 except twisted.internet.error.DNSLookupError as e: | ||||||
|                     logger.warn("Failed to fetch remoted media %r", e) |                     logger.warn("HTTP error fetching remote media %s/%s: %r", | ||||||
|                     raise SynapseError(502, "Failed to fetch remoted media") |                                 server_name, media_id, e) | ||||||
|  |                     raise NotFoundError() | ||||||
|  | 
 | ||||||
|  |                 except HttpResponseException as e: | ||||||
|  |                     logger.warn("HTTP error fetching remote media %s/%s: %s", | ||||||
|  |                                 server_name, media_id, e.response) | ||||||
|  |                     if e.code == twisted.web.http.NOT_FOUND: | ||||||
|  |                         raise SynapseError.from_http_response_exception(e) | ||||||
|  |                     raise SynapseError(502, "Failed to fetch remote media") | ||||||
|  | 
 | ||||||
|  |                 except SynapseError: | ||||||
|  |                     logger.exception("Failed to fetch remote media %s/%s", | ||||||
|  |                                      server_name, media_id) | ||||||
|  |                     raise | ||||||
|  | 
 | ||||||
|  |                 except Exception: | ||||||
|  |                     logger.exception("Failed to fetch remote media %s/%s", | ||||||
|  |                                      server_name, media_id) | ||||||
|  |                     raise SynapseError(502, "Failed to fetch remote media") | ||||||
| 
 | 
 | ||||||
|             media_type = headers["Content-Type"][0] |             media_type = headers["Content-Type"][0] | ||||||
|             time_now_ms = self.clock.time_msec() |             time_now_ms = self.clock.time_msec() | ||||||
|  | |||||||
| @ -177,17 +177,12 @@ class StateHandler(object): | |||||||
| 
 | 
 | ||||||
|     @defer.inlineCallbacks |     @defer.inlineCallbacks | ||||||
|     def compute_event_context(self, event, old_state=None): |     def compute_event_context(self, event, old_state=None): | ||||||
|         """ Fills out the context with the `current state` of the graph. The |         """Build an EventContext structure for the event. | ||||||
|         `current state` here is defined to be the state of the event graph |  | ||||||
|         just before the event - i.e. it never includes `event` |  | ||||||
| 
 |  | ||||||
|         If `event` has `auth_events` then this will also fill out the |  | ||||||
|         `auth_events` field on `context` from the `current_state`. |  | ||||||
| 
 | 
 | ||||||
|         Args: |         Args: | ||||||
|             event (EventBase) |             event (synapse.events.EventBase): | ||||||
|         Returns: |         Returns: | ||||||
|             an EventContext |             synapse.events.snapshot.EventContext: | ||||||
|         """ |         """ | ||||||
|         context = EventContext() |         context = EventContext() | ||||||
| 
 | 
 | ||||||
| @ -200,11 +195,11 @@ class StateHandler(object): | |||||||
|                     (s.type, s.state_key): s.event_id for s in old_state |                     (s.type, s.state_key): s.event_id for s in old_state | ||||||
|                 } |                 } | ||||||
|                 if event.is_state(): |                 if event.is_state(): | ||||||
|                     context.current_state_events = dict(context.prev_state_ids) |                     context.current_state_ids = dict(context.prev_state_ids) | ||||||
|                     key = (event.type, event.state_key) |                     key = (event.type, event.state_key) | ||||||
|                     context.current_state_events[key] = event.event_id |                     context.current_state_ids[key] = event.event_id | ||||||
|                 else: |                 else: | ||||||
|                     context.current_state_events = context.prev_state_ids |                     context.current_state_ids = context.prev_state_ids | ||||||
|             else: |             else: | ||||||
|                 context.current_state_ids = {} |                 context.current_state_ids = {} | ||||||
|                 context.prev_state_ids = {} |                 context.prev_state_ids = {} | ||||||
|  | |||||||
| @ -73,6 +73,9 @@ class LoggingTransaction(object): | |||||||
|     def __setattr__(self, name, value): |     def __setattr__(self, name, value): | ||||||
|         setattr(self.txn, name, value) |         setattr(self.txn, name, value) | ||||||
| 
 | 
 | ||||||
|  |     def __iter__(self): | ||||||
|  |         return self.txn.__iter__() | ||||||
|  | 
 | ||||||
|     def execute(self, sql, *args): |     def execute(self, sql, *args): | ||||||
|         self._do_execute(self.txn.execute, sql, *args) |         self._do_execute(self.txn.execute, sql, *args) | ||||||
| 
 | 
 | ||||||
| @ -132,7 +135,7 @@ class PerformanceCounters(object): | |||||||
| 
 | 
 | ||||||
|     def interval(self, interval_duration, limit=3): |     def interval(self, interval_duration, limit=3): | ||||||
|         counters = [] |         counters = [] | ||||||
|         for name, (count, cum_time) in self.current_counters.items(): |         for name, (count, cum_time) in self.current_counters.iteritems(): | ||||||
|             prev_count, prev_time = self.previous_counters.get(name, (0, 0)) |             prev_count, prev_time = self.previous_counters.get(name, (0, 0)) | ||||||
|             counters.append(( |             counters.append(( | ||||||
|                 (cum_time - prev_time) / interval_duration, |                 (cum_time - prev_time) / interval_duration, | ||||||
| @ -357,7 +360,7 @@ class SQLBaseStore(object): | |||||||
|         """ |         """ | ||||||
|         col_headers = list(intern(column[0]) for column in cursor.description) |         col_headers = list(intern(column[0]) for column in cursor.description) | ||||||
|         results = list( |         results = list( | ||||||
|             dict(zip(col_headers, row)) for row in cursor.fetchall() |             dict(zip(col_headers, row)) for row in cursor | ||||||
|         ) |         ) | ||||||
|         return results |         return results | ||||||
| 
 | 
 | ||||||
| @ -565,7 +568,7 @@ class SQLBaseStore(object): | |||||||
|     @staticmethod |     @staticmethod | ||||||
|     def _simple_select_onecol_txn(txn, table, keyvalues, retcol): |     def _simple_select_onecol_txn(txn, table, keyvalues, retcol): | ||||||
|         if keyvalues: |         if keyvalues: | ||||||
|             where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys()) |             where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys()) | ||||||
|         else: |         else: | ||||||
|             where = "" |             where = "" | ||||||
| 
 | 
 | ||||||
| @ -579,7 +582,7 @@ class SQLBaseStore(object): | |||||||
| 
 | 
 | ||||||
|         txn.execute(sql, keyvalues.values()) |         txn.execute(sql, keyvalues.values()) | ||||||
| 
 | 
 | ||||||
|         return [r[0] for r in txn.fetchall()] |         return [r[0] for r in txn] | ||||||
| 
 | 
 | ||||||
|     def _simple_select_onecol(self, table, keyvalues, retcol, |     def _simple_select_onecol(self, table, keyvalues, retcol, | ||||||
|                               desc="_simple_select_onecol"): |                               desc="_simple_select_onecol"): | ||||||
| @ -712,7 +715,7 @@ class SQLBaseStore(object): | |||||||
|         ) |         ) | ||||||
|         values.extend(iterable) |         values.extend(iterable) | ||||||
| 
 | 
 | ||||||
|         for key, value in keyvalues.items(): |         for key, value in keyvalues.iteritems(): | ||||||
|             clauses.append("%s = ?" % (key,)) |             clauses.append("%s = ?" % (key,)) | ||||||
|             values.append(value) |             values.append(value) | ||||||
| 
 | 
 | ||||||
| @ -753,7 +756,7 @@ class SQLBaseStore(object): | |||||||
|     @staticmethod |     @staticmethod | ||||||
|     def _simple_update_one_txn(txn, table, keyvalues, updatevalues): |     def _simple_update_one_txn(txn, table, keyvalues, updatevalues): | ||||||
|         if keyvalues: |         if keyvalues: | ||||||
|             where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys()) |             where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys()) | ||||||
|         else: |         else: | ||||||
|             where = "" |             where = "" | ||||||
| 
 | 
 | ||||||
| @ -840,6 +843,47 @@ class SQLBaseStore(object): | |||||||
| 
 | 
 | ||||||
|         return txn.execute(sql, keyvalues.values()) |         return txn.execute(sql, keyvalues.values()) | ||||||
| 
 | 
 | ||||||
|  |     def _simple_delete_many(self, table, column, iterable, keyvalues, desc): | ||||||
|  |         return self.runInteraction( | ||||||
|  |             desc, self._simple_delete_many_txn, table, column, iterable, keyvalues | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |     @staticmethod | ||||||
|  |     def _simple_delete_many_txn(txn, table, column, iterable, keyvalues): | ||||||
|  |         """Executes a DELETE query on the named table. | ||||||
|  | 
 | ||||||
|  |         Filters rows by if value of `column` is in `iterable`. | ||||||
|  | 
 | ||||||
|  |         Args: | ||||||
|  |             txn : Transaction object | ||||||
|  |             table : string giving the table name | ||||||
|  |             column : column name to test for inclusion against `iterable` | ||||||
|  |             iterable : list | ||||||
|  |             keyvalues : dict of column names and values to select the rows with | ||||||
|  |         """ | ||||||
|  |         if not iterable: | ||||||
|  |             return | ||||||
|  | 
 | ||||||
|  |         sql = "DELETE FROM %s" % table | ||||||
|  | 
 | ||||||
|  |         clauses = [] | ||||||
|  |         values = [] | ||||||
|  |         clauses.append( | ||||||
|  |             "%s IN (%s)" % (column, ",".join("?" for _ in iterable)) | ||||||
|  |         ) | ||||||
|  |         values.extend(iterable) | ||||||
|  | 
 | ||||||
|  |         for key, value in keyvalues.iteritems(): | ||||||
|  |             clauses.append("%s = ?" % (key,)) | ||||||
|  |             values.append(value) | ||||||
|  | 
 | ||||||
|  |         if clauses: | ||||||
|  |             sql = "%s WHERE %s" % ( | ||||||
|  |                 sql, | ||||||
|  |                 " AND ".join(clauses), | ||||||
|  |             ) | ||||||
|  |         return txn.execute(sql, values) | ||||||
|  | 
 | ||||||
|     def _get_cache_dict(self, db_conn, table, entity_column, stream_column, |     def _get_cache_dict(self, db_conn, table, entity_column, stream_column, | ||||||
|                         max_value, limit=100000): |                         max_value, limit=100000): | ||||||
|         # Fetch a mapping of room_id -> max stream position for "recent" rooms. |         # Fetch a mapping of room_id -> max stream position for "recent" rooms. | ||||||
| @ -860,16 +904,16 @@ class SQLBaseStore(object): | |||||||
| 
 | 
 | ||||||
|         txn = db_conn.cursor() |         txn = db_conn.cursor() | ||||||
|         txn.execute(sql, (int(max_value),)) |         txn.execute(sql, (int(max_value),)) | ||||||
|         rows = txn.fetchall() |  | ||||||
|         txn.close() |  | ||||||
| 
 | 
 | ||||||
|         cache = { |         cache = { | ||||||
|             row[0]: int(row[1]) |             row[0]: int(row[1]) | ||||||
|             for row in rows |             for row in txn | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|  |         txn.close() | ||||||
|  | 
 | ||||||
|         if cache: |         if cache: | ||||||
|             min_val = min(cache.values()) |             min_val = min(cache.itervalues()) | ||||||
|         else: |         else: | ||||||
|             min_val = max_value |             min_val = max_value | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -182,7 +182,7 @@ class AccountDataStore(SQLBaseStore): | |||||||
|             txn.execute(sql, (user_id, stream_id)) |             txn.execute(sql, (user_id, stream_id)) | ||||||
| 
 | 
 | ||||||
|             global_account_data = { |             global_account_data = { | ||||||
|                 row[0]: json.loads(row[1]) for row in txn.fetchall() |                 row[0]: json.loads(row[1]) for row in txn | ||||||
|             } |             } | ||||||
| 
 | 
 | ||||||
|             sql = ( |             sql = ( | ||||||
| @ -193,7 +193,7 @@ class AccountDataStore(SQLBaseStore): | |||||||
|             txn.execute(sql, (user_id, stream_id)) |             txn.execute(sql, (user_id, stream_id)) | ||||||
| 
 | 
 | ||||||
|             account_data_by_room = {} |             account_data_by_room = {} | ||||||
|             for row in txn.fetchall(): |             for row in txn: | ||||||
|                 room_account_data = account_data_by_room.setdefault(row[0], {}) |                 room_account_data = account_data_by_room.setdefault(row[0], {}) | ||||||
|                 room_account_data[row[1]] = json.loads(row[2]) |                 room_account_data[row[1]] = json.loads(row[2]) | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -12,6 +12,7 @@ | |||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
| # See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||||
| # limitations under the License. | # limitations under the License. | ||||||
|  | import synapse.util.async | ||||||
| 
 | 
 | ||||||
| from ._base import SQLBaseStore | from ._base import SQLBaseStore | ||||||
| from . import engines | from . import engines | ||||||
| @ -84,24 +85,14 @@ class BackgroundUpdateStore(SQLBaseStore): | |||||||
|         self._background_update_performance = {} |         self._background_update_performance = {} | ||||||
|         self._background_update_queue = [] |         self._background_update_queue = [] | ||||||
|         self._background_update_handlers = {} |         self._background_update_handlers = {} | ||||||
|         self._background_update_timer = None |  | ||||||
| 
 | 
 | ||||||
|     @defer.inlineCallbacks |     @defer.inlineCallbacks | ||||||
|     def start_doing_background_updates(self): |     def start_doing_background_updates(self): | ||||||
|         assert self._background_update_timer is None, \ |  | ||||||
|             "background updates already running" |  | ||||||
| 
 |  | ||||||
|         logger.info("Starting background schema updates") |         logger.info("Starting background schema updates") | ||||||
| 
 | 
 | ||||||
|         while True: |         while True: | ||||||
|             sleep = defer.Deferred() |             yield synapse.util.async.sleep( | ||||||
|             self._background_update_timer = self._clock.call_later( |                 self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.) | ||||||
|                 self.BACKGROUND_UPDATE_INTERVAL_MS / 1000., sleep.callback, None |  | ||||||
|             ) |  | ||||||
|             try: |  | ||||||
|                 yield sleep |  | ||||||
|             finally: |  | ||||||
|                 self._background_update_timer = None |  | ||||||
| 
 | 
 | ||||||
|             try: |             try: | ||||||
|                 result = yield self.do_next_background_update( |                 result = yield self.do_next_background_update( | ||||||
|  | |||||||
| @ -178,7 +178,7 @@ class DeviceInboxStore(BackgroundUpdateStore): | |||||||
|                 ) |                 ) | ||||||
|                 txn.execute(sql, (user_id,)) |                 txn.execute(sql, (user_id,)) | ||||||
|                 message_json = ujson.dumps(messages_by_device["*"]) |                 message_json = ujson.dumps(messages_by_device["*"]) | ||||||
|                 for row in txn.fetchall(): |                 for row in txn: | ||||||
|                     # Add the message for all devices for this user on this |                     # Add the message for all devices for this user on this | ||||||
|                     # server. |                     # server. | ||||||
|                     device = row[0] |                     device = row[0] | ||||||
| @ -195,7 +195,7 @@ class DeviceInboxStore(BackgroundUpdateStore): | |||||||
|                 # TODO: Maybe this needs to be done in batches if there are |                 # TODO: Maybe this needs to be done in batches if there are | ||||||
|                 # too many local devices for a given user. |                 # too many local devices for a given user. | ||||||
|                 txn.execute(sql, [user_id] + devices) |                 txn.execute(sql, [user_id] + devices) | ||||||
|                 for row in txn.fetchall(): |                 for row in txn: | ||||||
|                     # Only insert into the local inbox if the device exists on |                     # Only insert into the local inbox if the device exists on | ||||||
|                     # this server |                     # this server | ||||||
|                     device = row[0] |                     device = row[0] | ||||||
| @ -251,7 +251,7 @@ class DeviceInboxStore(BackgroundUpdateStore): | |||||||
|                 user_id, device_id, last_stream_id, current_stream_id, limit |                 user_id, device_id, last_stream_id, current_stream_id, limit | ||||||
|             )) |             )) | ||||||
|             messages = [] |             messages = [] | ||||||
|             for row in txn.fetchall(): |             for row in txn: | ||||||
|                 stream_pos = row[0] |                 stream_pos = row[0] | ||||||
|                 messages.append(ujson.loads(row[1])) |                 messages.append(ujson.loads(row[1])) | ||||||
|             if len(messages) < limit: |             if len(messages) < limit: | ||||||
| @ -340,7 +340,7 @@ class DeviceInboxStore(BackgroundUpdateStore): | |||||||
|                 " ORDER BY stream_id ASC" |                 " ORDER BY stream_id ASC" | ||||||
|             ) |             ) | ||||||
|             txn.execute(sql, (last_pos, upper_pos)) |             txn.execute(sql, (last_pos, upper_pos)) | ||||||
|             rows.extend(txn.fetchall()) |             rows.extend(txn) | ||||||
| 
 | 
 | ||||||
|             return rows |             return rows | ||||||
| 
 | 
 | ||||||
| @ -357,12 +357,12 @@ class DeviceInboxStore(BackgroundUpdateStore): | |||||||
|         """ |         """ | ||||||
|         Args: |         Args: | ||||||
|             destination(str): The name of the remote server. |             destination(str): The name of the remote server. | ||||||
|             last_stream_id(int): The last position of the device message stream |             last_stream_id(int|long): The last position of the device message stream | ||||||
|                 that the server sent up to. |                 that the server sent up to. | ||||||
|             current_stream_id(int): The current position of the device |             current_stream_id(int|long): The current position of the device | ||||||
|                 message stream. |                 message stream. | ||||||
|         Returns: |         Returns: | ||||||
|             Deferred ([dict], int): List of messages for the device and where |             Deferred ([dict], int|long): List of messages for the device and where | ||||||
|                 in the stream the messages got to. |                 in the stream the messages got to. | ||||||
|         """ |         """ | ||||||
| 
 | 
 | ||||||
| @ -384,7 +384,7 @@ class DeviceInboxStore(BackgroundUpdateStore): | |||||||
|                 destination, last_stream_id, current_stream_id, limit |                 destination, last_stream_id, current_stream_id, limit | ||||||
|             )) |             )) | ||||||
|             messages = [] |             messages = [] | ||||||
|             for row in txn.fetchall(): |             for row in txn: | ||||||
|                 stream_pos = row[0] |                 stream_pos = row[0] | ||||||
|                 messages.append(ujson.loads(row[1])) |                 messages.append(ujson.loads(row[1])) | ||||||
|             if len(messages) < limit: |             if len(messages) < limit: | ||||||
|  | |||||||
| @ -108,6 +108,23 @@ class DeviceStore(SQLBaseStore): | |||||||
|             desc="delete_device", |             desc="delete_device", | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|  |     def delete_devices(self, user_id, device_ids): | ||||||
|  |         """Deletes several devices. | ||||||
|  | 
 | ||||||
|  |         Args: | ||||||
|  |             user_id (str): The ID of the user which owns the devices | ||||||
|  |             device_ids (list): The IDs of the devices to delete | ||||||
|  |         Returns: | ||||||
|  |             defer.Deferred | ||||||
|  |         """ | ||||||
|  |         return self._simple_delete_many( | ||||||
|  |             table="devices", | ||||||
|  |             column="device_id", | ||||||
|  |             iterable=device_ids, | ||||||
|  |             keyvalues={"user_id": user_id}, | ||||||
|  |             desc="delete_devices", | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|     def update_device(self, user_id, device_id, new_display_name=None): |     def update_device(self, user_id, device_id, new_display_name=None): | ||||||
|         """Update a device. |         """Update a device. | ||||||
| 
 | 
 | ||||||
| @ -291,7 +308,7 @@ class DeviceStore(SQLBaseStore): | |||||||
|         """Get stream of updates to send to remote servers |         """Get stream of updates to send to remote servers | ||||||
| 
 | 
 | ||||||
|         Returns: |         Returns: | ||||||
|             (now_stream_id, [ { updates }, .. ]) |             (int, list[dict]): current stream id and list of updates | ||||||
|         """ |         """ | ||||||
|         now_stream_id = self._device_list_id_gen.get_current_token() |         now_stream_id = self._device_list_id_gen.get_current_token() | ||||||
| 
 | 
 | ||||||
| @ -312,17 +329,20 @@ class DeviceStore(SQLBaseStore): | |||||||
|             SELECT user_id, device_id, max(stream_id) FROM device_lists_outbound_pokes |             SELECT user_id, device_id, max(stream_id) FROM device_lists_outbound_pokes | ||||||
|             WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ? |             WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ? | ||||||
|             GROUP BY user_id, device_id |             GROUP BY user_id, device_id | ||||||
|  |             LIMIT 20 | ||||||
|         """ |         """ | ||||||
|         txn.execute( |         txn.execute( | ||||||
|             sql, (destination, from_stream_id, now_stream_id, False) |             sql, (destination, from_stream_id, now_stream_id, False) | ||||||
|         ) |         ) | ||||||
|         rows = txn.fetchall() |  | ||||||
| 
 |  | ||||||
|         if not rows: |  | ||||||
|             return (now_stream_id, []) |  | ||||||
| 
 | 
 | ||||||
|         # maps (user_id, device_id) -> stream_id |         # maps (user_id, device_id) -> stream_id | ||||||
|         query_map = {(r[0], r[1]): r[2] for r in rows} |         query_map = {(r[0], r[1]): r[2] for r in txn} | ||||||
|  |         if not query_map: | ||||||
|  |             return (now_stream_id, []) | ||||||
|  | 
 | ||||||
|  |         if len(query_map) >= 20: | ||||||
|  |             now_stream_id = max(stream_id for stream_id in query_map.itervalues()) | ||||||
|  | 
 | ||||||
|         devices = self._get_e2e_device_keys_txn( |         devices = self._get_e2e_device_keys_txn( | ||||||
|             txn, query_map.keys(), include_all_devices=True |             txn, query_map.keys(), include_all_devices=True | ||||||
|         ) |         ) | ||||||
|  | |||||||
| @ -14,6 +14,8 @@ | |||||||
| # limitations under the License. | # limitations under the License. | ||||||
| from twisted.internet import defer | from twisted.internet import defer | ||||||
| 
 | 
 | ||||||
|  | from synapse.api.errors import SynapseError | ||||||
|  | 
 | ||||||
| from canonicaljson import encode_canonical_json | from canonicaljson import encode_canonical_json | ||||||
| import ujson as json | import ujson as json | ||||||
| 
 | 
 | ||||||
| @ -120,24 +122,63 @@ class EndToEndKeyStore(SQLBaseStore): | |||||||
| 
 | 
 | ||||||
|         return result |         return result | ||||||
| 
 | 
 | ||||||
|  |     @defer.inlineCallbacks | ||||||
|     def add_e2e_one_time_keys(self, user_id, device_id, time_now, key_list): |     def add_e2e_one_time_keys(self, user_id, device_id, time_now, key_list): | ||||||
|         def _add_e2e_one_time_keys(txn): |         """Insert some new one time keys for a device. | ||||||
|             for (algorithm, key_id, json_bytes) in key_list: | 
 | ||||||
|                 self._simple_upsert_txn( |         Checks if any of the keys are already inserted, if they are then check | ||||||
|                     txn, table="e2e_one_time_keys_json", |         if they match. If they don't then we raise an error. | ||||||
|  |         """ | ||||||
|  | 
 | ||||||
|  |         # First we check if we have already persisted any of the keys. | ||||||
|  |         rows = yield self._simple_select_many_batch( | ||||||
|  |             table="e2e_one_time_keys_json", | ||||||
|  |             column="key_id", | ||||||
|  |             iterable=[key_id for _, key_id, _ in key_list], | ||||||
|  |             retcols=("algorithm", "key_id", "key_json",), | ||||||
|             keyvalues={ |             keyvalues={ | ||||||
|                 "user_id": user_id, |                 "user_id": user_id, | ||||||
|                 "device_id": device_id, |                 "device_id": device_id, | ||||||
|  |             }, | ||||||
|  |             desc="add_e2e_one_time_keys_check", | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         existing_key_map = { | ||||||
|  |             (row["algorithm"], row["key_id"]): row["key_json"] for row in rows | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         new_keys = []  # Keys that we need to insert | ||||||
|  |         for algorithm, key_id, json_bytes in key_list: | ||||||
|  |             ex_bytes = existing_key_map.get((algorithm, key_id), None) | ||||||
|  |             if ex_bytes: | ||||||
|  |                 if json_bytes != ex_bytes: | ||||||
|  |                     raise SynapseError( | ||||||
|  |                         400, "One time key with key_id %r already exists" % (key_id,) | ||||||
|  |                     ) | ||||||
|  |             else: | ||||||
|  |                 new_keys.append((algorithm, key_id, json_bytes)) | ||||||
|  | 
 | ||||||
|  |         def _add_e2e_one_time_keys(txn): | ||||||
|  |             # We are protected from race between lookup and insertion due to | ||||||
|  |             # a unique constraint. If there is a race of two calls to | ||||||
|  |             # `add_e2e_one_time_keys` then they'll conflict and we will only | ||||||
|  |             # insert one set. | ||||||
|  |             self._simple_insert_many_txn( | ||||||
|  |                 txn, table="e2e_one_time_keys_json", | ||||||
|  |                 values=[ | ||||||
|  |                     { | ||||||
|  |                         "user_id": user_id, | ||||||
|  |                         "device_id": device_id, | ||||||
|                         "algorithm": algorithm, |                         "algorithm": algorithm, | ||||||
|                         "key_id": key_id, |                         "key_id": key_id, | ||||||
|                     }, |  | ||||||
|                     values={ |  | ||||||
|                         "ts_added_ms": time_now, |                         "ts_added_ms": time_now, | ||||||
|                         "key_json": json_bytes, |                         "key_json": json_bytes, | ||||||
|                     } |                     } | ||||||
|  |                     for algorithm, key_id, json_bytes in new_keys | ||||||
|  |                 ], | ||||||
|             ) |             ) | ||||||
|         return self.runInteraction( |         yield self.runInteraction( | ||||||
|             "add_e2e_one_time_keys", _add_e2e_one_time_keys |             "add_e2e_one_time_keys_insert", _add_e2e_one_time_keys | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|     def count_e2e_one_time_keys(self, user_id, device_id): |     def count_e2e_one_time_keys(self, user_id, device_id): | ||||||
| @ -153,7 +194,7 @@ class EndToEndKeyStore(SQLBaseStore): | |||||||
|             ) |             ) | ||||||
|             txn.execute(sql, (user_id, device_id)) |             txn.execute(sql, (user_id, device_id)) | ||||||
|             result = {} |             result = {} | ||||||
|             for algorithm, key_count in txn.fetchall(): |             for algorithm, key_count in txn: | ||||||
|                 result[algorithm] = key_count |                 result[algorithm] = key_count | ||||||
|             return result |             return result | ||||||
|         return self.runInteraction( |         return self.runInteraction( | ||||||
| @ -174,7 +215,7 @@ class EndToEndKeyStore(SQLBaseStore): | |||||||
|                 user_result = result.setdefault(user_id, {}) |                 user_result = result.setdefault(user_id, {}) | ||||||
|                 device_result = user_result.setdefault(device_id, {}) |                 device_result = user_result.setdefault(device_id, {}) | ||||||
|                 txn.execute(sql, (user_id, device_id, algorithm)) |                 txn.execute(sql, (user_id, device_id, algorithm)) | ||||||
|                 for key_id, key_json in txn.fetchall(): |                 for key_id, key_json in txn: | ||||||
|                     device_result[algorithm + ":" + key_id] = key_json |                     device_result[algorithm + ":" + key_id] = key_json | ||||||
|                     delete.append((user_id, device_id, algorithm, key_id)) |                     delete.append((user_id, device_id, algorithm, key_id)) | ||||||
|             sql = ( |             sql = ( | ||||||
|  | |||||||
| @ -74,7 +74,7 @@ class EventFederationStore(SQLBaseStore): | |||||||
|                     base_sql % (",".join(["?"] * len(chunk)),), |                     base_sql % (",".join(["?"] * len(chunk)),), | ||||||
|                     chunk |                     chunk | ||||||
|                 ) |                 ) | ||||||
|                 new_front.update([r[0] for r in txn.fetchall()]) |                 new_front.update([r[0] for r in txn]) | ||||||
| 
 | 
 | ||||||
|             new_front -= results |             new_front -= results | ||||||
| 
 | 
 | ||||||
| @ -110,7 +110,7 @@ class EventFederationStore(SQLBaseStore): | |||||||
| 
 | 
 | ||||||
|         txn.execute(sql, (room_id, False,)) |         txn.execute(sql, (room_id, False,)) | ||||||
| 
 | 
 | ||||||
|         return dict(txn.fetchall()) |         return dict(txn) | ||||||
| 
 | 
 | ||||||
|     def _get_oldest_events_in_room_txn(self, txn, room_id): |     def _get_oldest_events_in_room_txn(self, txn, room_id): | ||||||
|         return self._simple_select_onecol_txn( |         return self._simple_select_onecol_txn( | ||||||
| @ -201,9 +201,9 @@ class EventFederationStore(SQLBaseStore): | |||||||
|     def _update_min_depth_for_room_txn(self, txn, room_id, depth): |     def _update_min_depth_for_room_txn(self, txn, room_id, depth): | ||||||
|         min_depth = self._get_min_depth_interaction(txn, room_id) |         min_depth = self._get_min_depth_interaction(txn, room_id) | ||||||
| 
 | 
 | ||||||
|         do_insert = depth < min_depth if min_depth else True |         if min_depth and depth >= min_depth: | ||||||
|  |             return | ||||||
| 
 | 
 | ||||||
|         if do_insert: |  | ||||||
|         self._simple_upsert_txn( |         self._simple_upsert_txn( | ||||||
|             txn, |             txn, | ||||||
|             table="room_depth", |             table="room_depth", | ||||||
| @ -334,8 +334,7 @@ class EventFederationStore(SQLBaseStore): | |||||||
| 
 | 
 | ||||||
|         def get_forward_extremeties_for_room_txn(txn): |         def get_forward_extremeties_for_room_txn(txn): | ||||||
|             txn.execute(sql, (stream_ordering, room_id)) |             txn.execute(sql, (stream_ordering, room_id)) | ||||||
|             rows = txn.fetchall() |             return [event_id for event_id, in txn] | ||||||
|             return [event_id for event_id, in rows] |  | ||||||
| 
 | 
 | ||||||
|         return self.runInteraction( |         return self.runInteraction( | ||||||
|             "get_forward_extremeties_for_room", |             "get_forward_extremeties_for_room", | ||||||
| @ -436,7 +435,7 @@ class EventFederationStore(SQLBaseStore): | |||||||
|                 (room_id, event_id, False, limit - len(event_results)) |                 (room_id, event_id, False, limit - len(event_results)) | ||||||
|             ) |             ) | ||||||
| 
 | 
 | ||||||
|             for row in txn.fetchall(): |             for row in txn: | ||||||
|                 if row[1] not in event_results: |                 if row[1] not in event_results: | ||||||
|                     queue.put((-row[0], row[1])) |                     queue.put((-row[0], row[1])) | ||||||
| 
 | 
 | ||||||
| @ -482,7 +481,7 @@ class EventFederationStore(SQLBaseStore): | |||||||
|                     (room_id, event_id, False, limit - len(event_results)) |                     (room_id, event_id, False, limit - len(event_results)) | ||||||
|                 ) |                 ) | ||||||
| 
 | 
 | ||||||
|                 for e_id, in txn.fetchall(): |                 for e_id, in txn: | ||||||
|                     new_front.add(e_id) |                     new_front.add(e_id) | ||||||
| 
 | 
 | ||||||
|             new_front -= earliest_events |             new_front -= earliest_events | ||||||
|  | |||||||
| @ -206,7 +206,7 @@ class EventPushActionsStore(SQLBaseStore): | |||||||
|                 " stream_ordering >= ? AND stream_ordering <= ?" |                 " stream_ordering >= ? AND stream_ordering <= ?" | ||||||
|             ) |             ) | ||||||
|             txn.execute(sql, (min_stream_ordering, max_stream_ordering)) |             txn.execute(sql, (min_stream_ordering, max_stream_ordering)) | ||||||
|             return [r[0] for r in txn.fetchall()] |             return [r[0] for r in txn] | ||||||
|         ret = yield self.runInteraction("get_push_action_users_in_range", f) |         ret = yield self.runInteraction("get_push_action_users_in_range", f) | ||||||
|         defer.returnValue(ret) |         defer.returnValue(ret) | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -34,14 +34,16 @@ from canonicaljson import encode_canonical_json | |||||||
| from collections import deque, namedtuple, OrderedDict | from collections import deque, namedtuple, OrderedDict | ||||||
| from functools import wraps | from functools import wraps | ||||||
| 
 | 
 | ||||||
| import synapse |  | ||||||
| import synapse.metrics | import synapse.metrics | ||||||
| 
 | 
 | ||||||
| 
 |  | ||||||
| import logging | import logging | ||||||
| import math | import math | ||||||
| import ujson as json | import ujson as json | ||||||
| 
 | 
 | ||||||
|  | # these are only included to make the type annotations work | ||||||
|  | from synapse.events import EventBase    # noqa: F401 | ||||||
|  | from synapse.events.snapshot import EventContext   # noqa: F401 | ||||||
|  | 
 | ||||||
| logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @ -82,6 +84,11 @@ class _EventPeristenceQueue(object): | |||||||
| 
 | 
 | ||||||
|     def add_to_queue(self, room_id, events_and_contexts, backfilled): |     def add_to_queue(self, room_id, events_and_contexts, backfilled): | ||||||
|         """Add events to the queue, with the given persist_event options. |         """Add events to the queue, with the given persist_event options. | ||||||
|  | 
 | ||||||
|  |         Args: | ||||||
|  |             room_id (str): | ||||||
|  |             events_and_contexts (list[(EventBase, EventContext)]): | ||||||
|  |             backfilled (bool): | ||||||
|         """ |         """ | ||||||
|         queue = self._event_persist_queues.setdefault(room_id, deque()) |         queue = self._event_persist_queues.setdefault(room_id, deque()) | ||||||
|         if queue: |         if queue: | ||||||
| @ -210,14 +217,14 @@ class EventsStore(SQLBaseStore): | |||||||
|             partitioned.setdefault(event.room_id, []).append((event, ctx)) |             partitioned.setdefault(event.room_id, []).append((event, ctx)) | ||||||
| 
 | 
 | ||||||
|         deferreds = [] |         deferreds = [] | ||||||
|         for room_id, evs_ctxs in partitioned.items(): |         for room_id, evs_ctxs in partitioned.iteritems(): | ||||||
|             d = preserve_fn(self._event_persist_queue.add_to_queue)( |             d = preserve_fn(self._event_persist_queue.add_to_queue)( | ||||||
|                 room_id, evs_ctxs, |                 room_id, evs_ctxs, | ||||||
|                 backfilled=backfilled, |                 backfilled=backfilled, | ||||||
|             ) |             ) | ||||||
|             deferreds.append(d) |             deferreds.append(d) | ||||||
| 
 | 
 | ||||||
|         for room_id in partitioned.keys(): |         for room_id in partitioned: | ||||||
|             self._maybe_start_persisting(room_id) |             self._maybe_start_persisting(room_id) | ||||||
| 
 | 
 | ||||||
|         return preserve_context_over_deferred( |         return preserve_context_over_deferred( | ||||||
| @ -227,6 +234,17 @@ class EventsStore(SQLBaseStore): | |||||||
|     @defer.inlineCallbacks |     @defer.inlineCallbacks | ||||||
|     @log_function |     @log_function | ||||||
|     def persist_event(self, event, context, backfilled=False): |     def persist_event(self, event, context, backfilled=False): | ||||||
|  |         """ | ||||||
|  | 
 | ||||||
|  |         Args: | ||||||
|  |             event (EventBase): | ||||||
|  |             context (EventContext): | ||||||
|  |             backfilled (bool): | ||||||
|  | 
 | ||||||
|  |         Returns: | ||||||
|  |             Deferred: resolves to (int, int): the stream ordering of ``event``, | ||||||
|  |             and the stream ordering of the latest persisted event | ||||||
|  |         """ | ||||||
|         deferred = self._event_persist_queue.add_to_queue( |         deferred = self._event_persist_queue.add_to_queue( | ||||||
|             event.room_id, [(event, context)], |             event.room_id, [(event, context)], | ||||||
|             backfilled=backfilled, |             backfilled=backfilled, | ||||||
| @ -253,6 +271,16 @@ class EventsStore(SQLBaseStore): | |||||||
|     @defer.inlineCallbacks |     @defer.inlineCallbacks | ||||||
|     def _persist_events(self, events_and_contexts, backfilled=False, |     def _persist_events(self, events_and_contexts, backfilled=False, | ||||||
|                         delete_existing=False): |                         delete_existing=False): | ||||||
|  |         """Persist events to db | ||||||
|  | 
 | ||||||
|  |         Args: | ||||||
|  |             events_and_contexts (list[(EventBase, EventContext)]): | ||||||
|  |             backfilled (bool): | ||||||
|  |             delete_existing (bool): | ||||||
|  | 
 | ||||||
|  |         Returns: | ||||||
|  |             Deferred: resolves when the events have been persisted | ||||||
|  |         """ | ||||||
|         if not events_and_contexts: |         if not events_and_contexts: | ||||||
|             return |             return | ||||||
| 
 | 
 | ||||||
| @ -295,7 +323,7 @@ class EventsStore(SQLBaseStore): | |||||||
|                                 (event, context) |                                 (event, context) | ||||||
|                             ) |                             ) | ||||||
| 
 | 
 | ||||||
|                         for room_id, ev_ctx_rm in events_by_room.items(): |                         for room_id, ev_ctx_rm in events_by_room.iteritems(): | ||||||
|                             # Work out new extremities by recursively adding and removing |                             # Work out new extremities by recursively adding and removing | ||||||
|                             # the new events. |                             # the new events. | ||||||
|                             latest_event_ids = yield self.get_latest_event_ids_in_room( |                             latest_event_ids = yield self.get_latest_event_ids_in_room( | ||||||
| @ -400,6 +428,7 @@ class EventsStore(SQLBaseStore): | |||||||
|         # Now we need to work out the different state sets for |         # Now we need to work out the different state sets for | ||||||
|         # each state extremities |         # each state extremities | ||||||
|         state_sets = [] |         state_sets = [] | ||||||
|  |         state_groups = set() | ||||||
|         missing_event_ids = [] |         missing_event_ids = [] | ||||||
|         was_updated = False |         was_updated = False | ||||||
|         for event_id in new_latest_event_ids: |         for event_id in new_latest_event_ids: | ||||||
| @ -409,9 +438,17 @@ class EventsStore(SQLBaseStore): | |||||||
|                 if event_id == ev.event_id: |                 if event_id == ev.event_id: | ||||||
|                     if ctx.current_state_ids is None: |                     if ctx.current_state_ids is None: | ||||||
|                         raise Exception("Unknown current state") |                         raise Exception("Unknown current state") | ||||||
|  | 
 | ||||||
|  |                     # If we've already seen the state group don't bother adding | ||||||
|  |                     # it to the state sets again | ||||||
|  |                     if ctx.state_group not in state_groups: | ||||||
|                         state_sets.append(ctx.current_state_ids) |                         state_sets.append(ctx.current_state_ids) | ||||||
|                         if ctx.delta_ids or hasattr(ev, "state_key"): |                         if ctx.delta_ids or hasattr(ev, "state_key"): | ||||||
|                             was_updated = True |                             was_updated = True | ||||||
|  |                         if ctx.state_group: | ||||||
|  |                             # Add this as a seen state group (if it has a state | ||||||
|  |                             # group) | ||||||
|  |                             state_groups.add(ctx.state_group) | ||||||
|                     break |                     break | ||||||
|             else: |             else: | ||||||
|                 # If we couldn't find it, then we'll need to pull |                 # If we couldn't find it, then we'll need to pull | ||||||
| @ -425,31 +462,57 @@ class EventsStore(SQLBaseStore): | |||||||
|                 missing_event_ids, |                 missing_event_ids, | ||||||
|             ) |             ) | ||||||
| 
 | 
 | ||||||
|             groups = set(event_to_groups.values()) |             groups = set(event_to_groups.itervalues()) - state_groups | ||||||
|             group_to_state = yield self._get_state_for_groups(groups) |  | ||||||
| 
 | 
 | ||||||
|             state_sets.extend(group_to_state.values()) |             if groups: | ||||||
|  |                 group_to_state = yield self._get_state_for_groups(groups) | ||||||
|  |                 state_sets.extend(group_to_state.itervalues()) | ||||||
| 
 | 
 | ||||||
|         if not new_latest_event_ids: |         if not new_latest_event_ids: | ||||||
|             current_state = {} |             current_state = {} | ||||||
|         elif was_updated: |         elif was_updated: | ||||||
|  |             if len(state_sets) == 1: | ||||||
|  |                 # If there is only one state set, then we know what the current | ||||||
|  |                 # state is. | ||||||
|  |                 current_state = state_sets[0] | ||||||
|  |             else: | ||||||
|  |                 # We work out the current state by passing the state sets to the | ||||||
|  |                 # state resolution algorithm. It may ask for some events, including | ||||||
|  |                 # the events we have yet to persist, so we need a slightly more | ||||||
|  |                 # complicated event lookup function than simply looking the events | ||||||
|  |                 # up in the db. | ||||||
|  |                 events_map = {ev.event_id: ev for ev, _ in events_context} | ||||||
|  | 
 | ||||||
|  |                 @defer.inlineCallbacks | ||||||
|  |                 def get_events(ev_ids): | ||||||
|  |                     # We get the events by first looking at the list of events we | ||||||
|  |                     # are trying to persist, and then fetching the rest from the DB. | ||||||
|  |                     db = [] | ||||||
|  |                     to_return = {} | ||||||
|  |                     for ev_id in ev_ids: | ||||||
|  |                         ev = events_map.get(ev_id, None) | ||||||
|  |                         if ev: | ||||||
|  |                             to_return[ev_id] = ev | ||||||
|  |                         else: | ||||||
|  |                             db.append(ev_id) | ||||||
|  | 
 | ||||||
|  |                     if db: | ||||||
|  |                         evs = yield self.get_events( | ||||||
|  |                             ev_ids, get_prev_content=False, check_redacted=False, | ||||||
|  |                         ) | ||||||
|  |                         to_return.update(evs) | ||||||
|  |                     defer.returnValue(to_return) | ||||||
|  | 
 | ||||||
|                 current_state = yield resolve_events( |                 current_state = yield resolve_events( | ||||||
|                     state_sets, |                     state_sets, | ||||||
|                 state_map_factory=lambda ev_ids: self.get_events( |                     state_map_factory=get_events, | ||||||
|                     ev_ids, get_prev_content=False, check_redacted=False, |  | ||||||
|                 ), |  | ||||||
|                 ) |                 ) | ||||||
|         else: |         else: | ||||||
|             return |             return | ||||||
| 
 | 
 | ||||||
|         existing_state_rows = yield self._simple_select_list( |         existing_state = yield self.get_current_state_ids(room_id) | ||||||
|             table="current_state_events", |  | ||||||
|             keyvalues={"room_id": room_id}, |  | ||||||
|             retcols=["event_id", "type", "state_key"], |  | ||||||
|             desc="_calculate_state_delta", |  | ||||||
|         ) |  | ||||||
| 
 | 
 | ||||||
|         existing_events = set(row["event_id"] for row in existing_state_rows) |         existing_events = set(existing_state.itervalues()) | ||||||
|         new_events = set(ev_id for ev_id in current_state.itervalues()) |         new_events = set(ev_id for ev_id in current_state.itervalues()) | ||||||
|         changed_events = existing_events ^ new_events |         changed_events = existing_events ^ new_events | ||||||
| 
 | 
 | ||||||
| @ -457,9 +520,8 @@ class EventsStore(SQLBaseStore): | |||||||
|             return |             return | ||||||
| 
 | 
 | ||||||
|         to_delete = { |         to_delete = { | ||||||
|             (row["type"], row["state_key"]): row["event_id"] |             key: ev_id for key, ev_id in existing_state.iteritems() | ||||||
|             for row in existing_state_rows |             if ev_id in changed_events | ||||||
|             if row["event_id"] in changed_events |  | ||||||
|         } |         } | ||||||
|         events_to_insert = (new_events - existing_events) |         events_to_insert = (new_events - existing_events) | ||||||
|         to_insert = { |         to_insert = { | ||||||
| @ -535,11 +597,91 @@ class EventsStore(SQLBaseStore): | |||||||
|         and the rejections table. Things reading from those table will need to check |         and the rejections table. Things reading from those table will need to check | ||||||
|         whether the event was rejected. |         whether the event was rejected. | ||||||
| 
 | 
 | ||||||
|         If delete_existing is True then existing events will be purged from the |         Args: | ||||||
|         database before insertion. This is useful when retrying due to IntegrityError. |             txn (twisted.enterprise.adbapi.Connection): db connection | ||||||
|  |             events_and_contexts (list[(EventBase, EventContext)]): | ||||||
|  |                 events to persist | ||||||
|  |             backfilled (bool): True if the events were backfilled | ||||||
|  |             delete_existing (bool): True to purge existing table rows for the | ||||||
|  |                 events from the database. This is useful when retrying due to | ||||||
|  |                 IntegrityError. | ||||||
|  |             current_state_for_room (dict[str, (list[str], list[str])]): | ||||||
|  |                 The current-state delta for each room. For each room, a tuple | ||||||
|  |                 (to_delete, to_insert), being a list of event ids to be removed | ||||||
|  |                 from the current state, and a list of event ids to be added to | ||||||
|  |                 the current state. | ||||||
|  |             new_forward_extremeties (dict[str, list[str]]): | ||||||
|  |                 The new forward extremities for each room. For each room, a | ||||||
|  |                 list of the event ids which are the forward extremities. | ||||||
|  | 
 | ||||||
|         """ |         """ | ||||||
|  |         self._update_current_state_txn(txn, current_state_for_room) | ||||||
|  | 
 | ||||||
|         max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering |         max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering | ||||||
|         for room_id, current_state_tuple in current_state_for_room.iteritems(): |         self._update_forward_extremities_txn( | ||||||
|  |             txn, | ||||||
|  |             new_forward_extremities=new_forward_extremeties, | ||||||
|  |             max_stream_order=max_stream_order, | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         # Ensure that we don't have the same event twice. | ||||||
|  |         events_and_contexts = self._filter_events_and_contexts_for_duplicates( | ||||||
|  |             events_and_contexts, | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         self._update_room_depths_txn( | ||||||
|  |             txn, | ||||||
|  |             events_and_contexts=events_and_contexts, | ||||||
|  |             backfilled=backfilled, | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         # _update_outliers_txn filters out any events which have already been | ||||||
|  |         # persisted, and returns the filtered list. | ||||||
|  |         events_and_contexts = self._update_outliers_txn( | ||||||
|  |             txn, | ||||||
|  |             events_and_contexts=events_and_contexts, | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         # From this point onwards the events are only events that we haven't | ||||||
|  |         # seen before. | ||||||
|  | 
 | ||||||
|  |         if delete_existing: | ||||||
|  |             # For paranoia reasons, we go and delete all the existing entries | ||||||
|  |             # for these events so we can reinsert them. | ||||||
|  |             # This gets around any problems with some tables already having | ||||||
|  |             # entries. | ||||||
|  |             self._delete_existing_rows_txn( | ||||||
|  |                 txn, | ||||||
|  |                 events_and_contexts=events_and_contexts, | ||||||
|  |             ) | ||||||
|  | 
 | ||||||
|  |         self._store_event_txn( | ||||||
|  |             txn, | ||||||
|  |             events_and_contexts=events_and_contexts, | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         # Insert into the state_groups, state_groups_state, and | ||||||
|  |         # event_to_state_groups tables. | ||||||
|  |         self._store_mult_state_groups_txn(txn, events_and_contexts) | ||||||
|  | 
 | ||||||
|  |         # _store_rejected_events_txn filters out any events which were | ||||||
|  |         # rejected, and returns the filtered list. | ||||||
|  |         events_and_contexts = self._store_rejected_events_txn( | ||||||
|  |             txn, | ||||||
|  |             events_and_contexts=events_and_contexts, | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         # From this point onwards the events are only ones that weren't | ||||||
|  |         # rejected. | ||||||
|  | 
 | ||||||
|  |         self._update_metadata_tables_txn( | ||||||
|  |             txn, | ||||||
|  |             events_and_contexts=events_and_contexts, | ||||||
|  |             backfilled=backfilled, | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |     def _update_current_state_txn(self, txn, state_delta_by_room): | ||||||
|  |         for room_id, current_state_tuple in state_delta_by_room.iteritems(): | ||||||
|                 to_delete, to_insert = current_state_tuple |                 to_delete, to_insert = current_state_tuple | ||||||
|                 txn.executemany( |                 txn.executemany( | ||||||
|                     "DELETE FROM current_state_events WHERE event_id = ?", |                     "DELETE FROM current_state_events WHERE event_id = ?", | ||||||
| @ -585,7 +727,13 @@ class EventsStore(SQLBaseStore): | |||||||
|                     txn, self.get_users_in_room, (room_id,) |                     txn, self.get_users_in_room, (room_id,) | ||||||
|                 ) |                 ) | ||||||
| 
 | 
 | ||||||
|         for room_id, new_extrem in new_forward_extremeties.items(): |                 self._invalidate_cache_and_stream( | ||||||
|  |                     txn, self.get_current_state_ids, (room_id,) | ||||||
|  |                 ) | ||||||
|  | 
 | ||||||
|  |     def _update_forward_extremities_txn(self, txn, new_forward_extremities, | ||||||
|  |                                         max_stream_order): | ||||||
|  |         for room_id, new_extrem in new_forward_extremities.iteritems(): | ||||||
|             self._simple_delete_txn( |             self._simple_delete_txn( | ||||||
|                 txn, |                 txn, | ||||||
|                 table="event_forward_extremities", |                 table="event_forward_extremities", | ||||||
| @ -603,7 +751,7 @@ class EventsStore(SQLBaseStore): | |||||||
|                     "event_id": ev_id, |                     "event_id": ev_id, | ||||||
|                     "room_id": room_id, |                     "room_id": room_id, | ||||||
|                 } |                 } | ||||||
|                 for room_id, new_extrem in new_forward_extremeties.items() |                 for room_id, new_extrem in new_forward_extremities.iteritems() | ||||||
|                 for ev_id in new_extrem |                 for ev_id in new_extrem | ||||||
|             ], |             ], | ||||||
|         ) |         ) | ||||||
| @ -620,13 +768,22 @@ class EventsStore(SQLBaseStore): | |||||||
|                     "event_id": event_id, |                     "event_id": event_id, | ||||||
|                     "stream_ordering": max_stream_order, |                     "stream_ordering": max_stream_order, | ||||||
|                 } |                 } | ||||||
|                 for room_id, new_extrem in new_forward_extremeties.items() |                 for room_id, new_extrem in new_forward_extremities.iteritems() | ||||||
|                 for event_id in new_extrem |                 for event_id in new_extrem | ||||||
|             ] |             ] | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|         # Ensure that we don't have the same event twice. |     @classmethod | ||||||
|         # Pick the earliest non-outlier if there is one, else the earliest one. |     def _filter_events_and_contexts_for_duplicates(cls, events_and_contexts): | ||||||
|  |         """Ensure that we don't have the same event twice. | ||||||
|  | 
 | ||||||
|  |         Pick the earliest non-outlier if there is one, else the earliest one. | ||||||
|  | 
 | ||||||
|  |         Args: | ||||||
|  |             events_and_contexts (list[(EventBase, EventContext)]): | ||||||
|  |         Returns: | ||||||
|  |             list[(EventBase, EventContext)]: filtered list | ||||||
|  |         """ | ||||||
|         new_events_and_contexts = OrderedDict() |         new_events_and_contexts = OrderedDict() | ||||||
|         for event, context in events_and_contexts: |         for event, context in events_and_contexts: | ||||||
|             prev_event_context = new_events_and_contexts.get(event.event_id) |             prev_event_context = new_events_and_contexts.get(event.event_id) | ||||||
| @ -639,9 +796,17 @@ class EventsStore(SQLBaseStore): | |||||||
|                         new_events_and_contexts[event.event_id] = (event, context) |                         new_events_and_contexts[event.event_id] = (event, context) | ||||||
|             else: |             else: | ||||||
|                 new_events_and_contexts[event.event_id] = (event, context) |                 new_events_and_contexts[event.event_id] = (event, context) | ||||||
|  |         return new_events_and_contexts.values() | ||||||
| 
 | 
 | ||||||
|         events_and_contexts = new_events_and_contexts.values() |     def _update_room_depths_txn(self, txn, events_and_contexts, backfilled): | ||||||
|  |         """Update min_depth for each room | ||||||
| 
 | 
 | ||||||
|  |         Args: | ||||||
|  |             txn (twisted.enterprise.adbapi.Connection): db connection | ||||||
|  |             events_and_contexts (list[(EventBase, EventContext)]): events | ||||||
|  |                 we are persisting | ||||||
|  |             backfilled (bool): True if the events were backfilled | ||||||
|  |         """ | ||||||
|         depth_updates = {} |         depth_updates = {} | ||||||
|         for event, context in events_and_contexts: |         for event, context in events_and_contexts: | ||||||
|             # Remove the any existing cache entries for the event_ids |             # Remove the any existing cache entries for the event_ids | ||||||
| @ -657,9 +822,24 @@ class EventsStore(SQLBaseStore): | |||||||
|                     event.depth, depth_updates.get(event.room_id, event.depth) |                     event.depth, depth_updates.get(event.room_id, event.depth) | ||||||
|                 ) |                 ) | ||||||
| 
 | 
 | ||||||
|         for room_id, depth in depth_updates.items(): |         for room_id, depth in depth_updates.iteritems(): | ||||||
|             self._update_min_depth_for_room_txn(txn, room_id, depth) |             self._update_min_depth_for_room_txn(txn, room_id, depth) | ||||||
| 
 | 
 | ||||||
|  |     def _update_outliers_txn(self, txn, events_and_contexts): | ||||||
|  |         """Update any outliers with new event info. | ||||||
|  | 
 | ||||||
|  |         This turns outliers into ex-outliers (unless the new event was | ||||||
|  |         rejected). | ||||||
|  | 
 | ||||||
|  |         Args: | ||||||
|  |             txn (twisted.enterprise.adbapi.Connection): db connection | ||||||
|  |             events_and_contexts (list[(EventBase, EventContext)]): events | ||||||
|  |                 we are persisting | ||||||
|  | 
 | ||||||
|  |         Returns: | ||||||
|  |             list[(EventBase, EventContext)] new list, without events which | ||||||
|  |             are already in the events table. | ||||||
|  |         """ | ||||||
|         txn.execute( |         txn.execute( | ||||||
|             "SELECT event_id, outlier FROM events WHERE event_id in (%s)" % ( |             "SELECT event_id, outlier FROM events WHERE event_id in (%s)" % ( | ||||||
|                 ",".join(["?"] * len(events_and_contexts)), |                 ",".join(["?"] * len(events_and_contexts)), | ||||||
| @ -669,24 +849,21 @@ class EventsStore(SQLBaseStore): | |||||||
| 
 | 
 | ||||||
|         have_persisted = { |         have_persisted = { | ||||||
|             event_id: outlier |             event_id: outlier | ||||||
|             for event_id, outlier in txn.fetchall() |             for event_id, outlier in txn | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         to_remove = set() |         to_remove = set() | ||||||
|         for event, context in events_and_contexts: |         for event, context in events_and_contexts: | ||||||
|             if context.rejected: |  | ||||||
|                 # If the event is rejected then we don't care if the event |  | ||||||
|                 # was an outlier or not. |  | ||||||
|                 if event.event_id in have_persisted: |  | ||||||
|                     # If we have already seen the event then ignore it. |  | ||||||
|                     to_remove.add(event) |  | ||||||
|                 continue |  | ||||||
| 
 |  | ||||||
|             if event.event_id not in have_persisted: |             if event.event_id not in have_persisted: | ||||||
|                 continue |                 continue | ||||||
| 
 | 
 | ||||||
|             to_remove.add(event) |             to_remove.add(event) | ||||||
| 
 | 
 | ||||||
|  |             if context.rejected: | ||||||
|  |                 # If the event is rejected then we don't care if the event | ||||||
|  |                 # was an outlier or not. | ||||||
|  |                 continue | ||||||
|  | 
 | ||||||
|             outlier_persisted = have_persisted[event.event_id] |             outlier_persisted = have_persisted[event.event_id] | ||||||
|             if not event.internal_metadata.is_outlier() and outlier_persisted: |             if not event.internal_metadata.is_outlier() and outlier_persisted: | ||||||
|                 # We received a copy of an event that we had already stored as |                 # We received a copy of an event that we had already stored as | ||||||
| @ -741,34 +918,16 @@ class EventsStore(SQLBaseStore): | |||||||
|                 # event isn't an outlier any more. |                 # event isn't an outlier any more. | ||||||
|                 self._update_backward_extremeties(txn, [event]) |                 self._update_backward_extremeties(txn, [event]) | ||||||
| 
 | 
 | ||||||
|         events_and_contexts = [ |         return [ | ||||||
|             ec for ec in events_and_contexts if ec[0] not in to_remove |             ec for ec in events_and_contexts if ec[0] not in to_remove | ||||||
|         ] |         ] | ||||||
| 
 | 
 | ||||||
|  |     @classmethod | ||||||
|  |     def _delete_existing_rows_txn(cls, txn, events_and_contexts): | ||||||
|         if not events_and_contexts: |         if not events_and_contexts: | ||||||
|             # Make sure we don't pass an empty list to functions that expect to |             # nothing to do here | ||||||
|             # be storing at least one element. |  | ||||||
|             return |             return | ||||||
| 
 | 
 | ||||||
|         # From this point onwards the events are only events that we haven't |  | ||||||
|         # seen before. |  | ||||||
| 
 |  | ||||||
|         def event_dict(event): |  | ||||||
|             return { |  | ||||||
|                 k: v |  | ||||||
|                 for k, v in event.get_dict().items() |  | ||||||
|                 if k not in [ |  | ||||||
|                     "redacted", |  | ||||||
|                     "redacted_because", |  | ||||||
|                 ] |  | ||||||
|             } |  | ||||||
| 
 |  | ||||||
|         if delete_existing: |  | ||||||
|             # For paranoia reasons, we go and delete all the existing entries |  | ||||||
|             # for these events so we can reinsert them. |  | ||||||
|             # This gets around any problems with some tables already having |  | ||||||
|             # entries. |  | ||||||
| 
 |  | ||||||
|         logger.info("Deleting existing") |         logger.info("Deleting existing") | ||||||
| 
 | 
 | ||||||
|         for table in ( |         for table in ( | ||||||
| @ -800,6 +959,25 @@ class EventsStore(SQLBaseStore): | |||||||
|                 [(ev.event_id,) for ev, _ in events_and_contexts] |                 [(ev.event_id,) for ev, _ in events_and_contexts] | ||||||
|             ) |             ) | ||||||
| 
 | 
 | ||||||
|  |     def _store_event_txn(self, txn, events_and_contexts): | ||||||
|  |         """Insert new events into the event and event_json tables | ||||||
|  | 
 | ||||||
|  |         Args: | ||||||
|  |             txn (twisted.enterprise.adbapi.Connection): db connection | ||||||
|  |             events_and_contexts (list[(EventBase, EventContext)]): events | ||||||
|  |                 we are persisting | ||||||
|  |         """ | ||||||
|  | 
 | ||||||
|  |         if not events_and_contexts: | ||||||
|  |             # nothing to do here | ||||||
|  |             return | ||||||
|  | 
 | ||||||
|  |         def event_dict(event): | ||||||
|  |             d = event.get_dict() | ||||||
|  |             d.pop("redacted", None) | ||||||
|  |             d.pop("redacted_because", None) | ||||||
|  |             return d | ||||||
|  | 
 | ||||||
|         self._simple_insert_many_txn( |         self._simple_insert_many_txn( | ||||||
|             txn, |             txn, | ||||||
|             table="event_json", |             table="event_json", | ||||||
| @ -842,6 +1020,19 @@ class EventsStore(SQLBaseStore): | |||||||
|             ], |             ], | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|  |     def _store_rejected_events_txn(self, txn, events_and_contexts): | ||||||
|  |         """Add rows to the 'rejections' table for received events which were | ||||||
|  |         rejected | ||||||
|  | 
 | ||||||
|  |         Args: | ||||||
|  |             txn (twisted.enterprise.adbapi.Connection): db connection | ||||||
|  |             events_and_contexts (list[(EventBase, EventContext)]): events | ||||||
|  |                 we are persisting | ||||||
|  | 
 | ||||||
|  |         Returns: | ||||||
|  |             list[(EventBase, EventContext)] new list, without the rejected | ||||||
|  |                 events. | ||||||
|  |         """ | ||||||
|         # Remove the rejected events from the list now that we've added them |         # Remove the rejected events from the list now that we've added them | ||||||
|         # to the events table and the events_json table. |         # to the events table and the events_json table. | ||||||
|         to_remove = set() |         to_remove = set() | ||||||
| @ -853,16 +1044,23 @@ class EventsStore(SQLBaseStore): | |||||||
|                 ) |                 ) | ||||||
|                 to_remove.add(event) |                 to_remove.add(event) | ||||||
| 
 | 
 | ||||||
|         events_and_contexts = [ |         return [ | ||||||
|             ec for ec in events_and_contexts if ec[0] not in to_remove |             ec for ec in events_and_contexts if ec[0] not in to_remove | ||||||
|         ] |         ] | ||||||
| 
 | 
 | ||||||
|         if not events_and_contexts: |     def _update_metadata_tables_txn(self, txn, events_and_contexts, backfilled): | ||||||
|             # Make sure we don't pass an empty list to functions that expect to |         """Update all the miscellaneous tables for new events | ||||||
|             # be storing at least one element. |  | ||||||
|             return |  | ||||||
| 
 | 
 | ||||||
|         # From this point onwards the events are only ones that weren't rejected. |         Args: | ||||||
|  |             txn (twisted.enterprise.adbapi.Connection): db connection | ||||||
|  |             events_and_contexts (list[(EventBase, EventContext)]): events | ||||||
|  |                 we are persisting | ||||||
|  |             backfilled (bool): True if the events were backfilled | ||||||
|  |         """ | ||||||
|  | 
 | ||||||
|  |         if not events_and_contexts: | ||||||
|  |             # nothing to do here | ||||||
|  |             return | ||||||
| 
 | 
 | ||||||
|         for event, context in events_and_contexts: |         for event, context in events_and_contexts: | ||||||
|             # Insert all the push actions into the event_push_actions table. |             # Insert all the push actions into the event_push_actions table. | ||||||
| @ -892,10 +1090,6 @@ class EventsStore(SQLBaseStore): | |||||||
|             ], |             ], | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|         # Insert into the state_groups, state_groups_state, and |  | ||||||
|         # event_to_state_groups tables. |  | ||||||
|         self._store_mult_state_groups_txn(txn, events_and_contexts) |  | ||||||
| 
 |  | ||||||
|         # Update the event_forward_extremities, event_backward_extremities and |         # Update the event_forward_extremities, event_backward_extremities and | ||||||
|         # event_edges tables. |         # event_edges tables. | ||||||
|         self._handle_mult_prev_events( |         self._handle_mult_prev_events( | ||||||
| @ -982,13 +1176,6 @@ class EventsStore(SQLBaseStore): | |||||||
|         # Prefill the event cache |         # Prefill the event cache | ||||||
|         self._add_to_cache(txn, events_and_contexts) |         self._add_to_cache(txn, events_and_contexts) | ||||||
| 
 | 
 | ||||||
|         if backfilled: |  | ||||||
|             # Backfilled events come before the current state so we don't need |  | ||||||
|             # to update the current state table |  | ||||||
|             return |  | ||||||
| 
 |  | ||||||
|         return |  | ||||||
| 
 |  | ||||||
|     def _add_to_cache(self, txn, events_and_contexts): |     def _add_to_cache(self, txn, events_and_contexts): | ||||||
|         to_prefill = [] |         to_prefill = [] | ||||||
| 
 | 
 | ||||||
| @ -1597,14 +1784,13 @@ class EventsStore(SQLBaseStore): | |||||||
| 
 | 
 | ||||||
|         def get_all_new_events_txn(txn): |         def get_all_new_events_txn(txn): | ||||||
|             sql = ( |             sql = ( | ||||||
|                 "SELECT e.stream_ordering, ej.internal_metadata, ej.json, eg.state_group" |                 "SELECT e.stream_ordering, e.event_id, e.room_id, e.type," | ||||||
|                 " FROM events as e" |                 " state_key, redacts" | ||||||
|                 " JOIN event_json as ej" |                 " FROM events AS e" | ||||||
|                 " ON e.event_id = ej.event_id AND e.room_id = ej.room_id" |                 " LEFT JOIN redactions USING (event_id)" | ||||||
|                 " LEFT JOIN event_to_state_groups as eg" |                 " LEFT JOIN state_events USING (event_id)" | ||||||
|                 " ON e.event_id = eg.event_id" |                 " WHERE ? < stream_ordering AND stream_ordering <= ?" | ||||||
|                 " WHERE ? < e.stream_ordering AND e.stream_ordering <= ?" |                 " ORDER BY stream_ordering ASC" | ||||||
|                 " ORDER BY e.stream_ordering ASC" |  | ||||||
|                 " LIMIT ?" |                 " LIMIT ?" | ||||||
|             ) |             ) | ||||||
|             if have_forward_events: |             if have_forward_events: | ||||||
| @ -1630,15 +1816,13 @@ class EventsStore(SQLBaseStore): | |||||||
|                 forward_ex_outliers = [] |                 forward_ex_outliers = [] | ||||||
| 
 | 
 | ||||||
|             sql = ( |             sql = ( | ||||||
|                 "SELECT -e.stream_ordering, ej.internal_metadata, ej.json," |                 "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type," | ||||||
|                 " eg.state_group" |                 " state_key, redacts" | ||||||
|                 " FROM events as e" |                 " FROM events AS e" | ||||||
|                 " JOIN event_json as ej" |                 " LEFT JOIN redactions USING (event_id)" | ||||||
|                 " ON e.event_id = ej.event_id AND e.room_id = ej.room_id" |                 " LEFT JOIN state_events USING (event_id)" | ||||||
|                 " LEFT JOIN event_to_state_groups as eg" |                 " WHERE ? > stream_ordering AND stream_ordering >= ?" | ||||||
|                 " ON e.event_id = eg.event_id" |                 " ORDER BY stream_ordering DESC" | ||||||
|                 " WHERE ? > e.stream_ordering AND e.stream_ordering >= ?" |  | ||||||
|                 " ORDER BY e.stream_ordering DESC" |  | ||||||
|                 " LIMIT ?" |                 " LIMIT ?" | ||||||
|             ) |             ) | ||||||
|             if have_backfill_events: |             if have_backfill_events: | ||||||
| @ -1825,7 +2009,7 @@ class EventsStore(SQLBaseStore): | |||||||
|                         "state_key": key[1], |                         "state_key": key[1], | ||||||
|                         "event_id": state_id, |                         "event_id": state_id, | ||||||
|                     } |                     } | ||||||
|                     for key, state_id in curr_state.items() |                     for key, state_id in curr_state.iteritems() | ||||||
|                 ], |                 ], | ||||||
|             ) |             ) | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -101,9 +101,10 @@ class KeyStore(SQLBaseStore): | |||||||
|         key_ids |         key_ids | ||||||
|         Args: |         Args: | ||||||
|             server_name (str): The name of the server. |             server_name (str): The name of the server. | ||||||
|             key_ids (list of str): List of key_ids to try and look up. |             key_ids (iterable[str]): key_ids to try and look up. | ||||||
|         Returns: |         Returns: | ||||||
|             (list of VerifyKey): The verification keys. |             Deferred: resolves to dict[str, VerifyKey]: map from | ||||||
|  |                key_id to verification key. | ||||||
|         """ |         """ | ||||||
|         keys = {} |         keys = {} | ||||||
|         for key_id in key_ids: |         for key_id in key_ids: | ||||||
|  | |||||||
| @ -356,7 +356,7 @@ def _get_or_create_schema_state(txn, database_engine): | |||||||
|             ), |             ), | ||||||
|             (current_version,) |             (current_version,) | ||||||
|         ) |         ) | ||||||
|         applied_deltas = [d for d, in txn.fetchall()] |         applied_deltas = [d for d, in txn] | ||||||
|         return current_version, applied_deltas, upgraded |         return current_version, applied_deltas, upgraded | ||||||
| 
 | 
 | ||||||
|     return None |     return None | ||||||
|  | |||||||
| @ -85,8 +85,8 @@ class PresenceStore(SQLBaseStore): | |||||||
|                 self.presence_stream_cache.entity_has_changed, |                 self.presence_stream_cache.entity_has_changed, | ||||||
|                 state.user_id, stream_id, |                 state.user_id, stream_id, | ||||||
|             ) |             ) | ||||||
|             self._invalidate_cache_and_stream( |             txn.call_after( | ||||||
|                 txn, self._get_presence_for_user, (state.user_id,) |                 self._get_presence_for_user.invalidate, (state.user_id,) | ||||||
|             ) |             ) | ||||||
| 
 | 
 | ||||||
|         # Actually insert new rows |         # Actually insert new rows | ||||||
|  | |||||||
| @ -313,10 +313,9 @@ class ReceiptsStore(SQLBaseStore): | |||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|         txn.execute(sql, (room_id, receipt_type, user_id)) |         txn.execute(sql, (room_id, receipt_type, user_id)) | ||||||
|         results = txn.fetchall() |  | ||||||
| 
 | 
 | ||||||
|         if results and topological_ordering: |         if topological_ordering: | ||||||
|             for to, so, _ in results: |             for to, so, _ in txn: | ||||||
|                 if int(to) > topological_ordering: |                 if int(to) > topological_ordering: | ||||||
|                     return False |                     return False | ||||||
|                 elif int(to) == topological_ordering and int(so) >= stream_ordering: |                 elif int(to) == topological_ordering and int(so) >= stream_ordering: | ||||||
|  | |||||||
| @ -209,7 +209,7 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): | |||||||
|                 " WHERE lower(name) = lower(?)" |                 " WHERE lower(name) = lower(?)" | ||||||
|             ) |             ) | ||||||
|             txn.execute(sql, (user_id,)) |             txn.execute(sql, (user_id,)) | ||||||
|             return dict(txn.fetchall()) |             return dict(txn) | ||||||
| 
 | 
 | ||||||
|         return self.runInteraction("get_users_by_id_case_insensitive", f) |         return self.runInteraction("get_users_by_id_case_insensitive", f) | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -396,7 +396,7 @@ class RoomStore(SQLBaseStore): | |||||||
|                     sql % ("AND appservice_id IS NULL",), |                     sql % ("AND appservice_id IS NULL",), | ||||||
|                     (stream_id,) |                     (stream_id,) | ||||||
|                 ) |                 ) | ||||||
|             return dict(txn.fetchall()) |             return dict(txn) | ||||||
|         else: |         else: | ||||||
|             # We want to get from all lists, so we need to aggregate the results |             # We want to get from all lists, so we need to aggregate the results | ||||||
| 
 | 
 | ||||||
| @ -422,7 +422,7 @@ class RoomStore(SQLBaseStore): | |||||||
| 
 | 
 | ||||||
|             results = {} |             results = {} | ||||||
|             # A room is visible if its visible on any list. |             # A room is visible if its visible on any list. | ||||||
|             for room_id, visibility in txn.fetchall(): |             for room_id, visibility in txn: | ||||||
|                 results[room_id] = bool(visibility) or results.get(room_id, False) |                 results[room_id] = bool(visibility) or results.get(room_id, False) | ||||||
| 
 | 
 | ||||||
|             return results |             return results | ||||||
|  | |||||||
| @ -129,17 +129,30 @@ class RoomMemberStore(SQLBaseStore): | |||||||
|         with self._stream_id_gen.get_next() as stream_ordering: |         with self._stream_id_gen.get_next() as stream_ordering: | ||||||
|             yield self.runInteraction("locally_reject_invite", f, stream_ordering) |             yield self.runInteraction("locally_reject_invite", f, stream_ordering) | ||||||
| 
 | 
 | ||||||
|  |     @cachedInlineCallbacks(max_entries=100000, iterable=True, cache_context=True) | ||||||
|  |     def get_hosts_in_room(self, room_id, cache_context): | ||||||
|  |         """Returns the set of all hosts currently in the room | ||||||
|  |         """ | ||||||
|  |         user_ids = yield self.get_users_in_room( | ||||||
|  |             room_id, on_invalidate=cache_context.invalidate, | ||||||
|  |         ) | ||||||
|  |         hosts = frozenset(get_domain_from_id(user_id) for user_id in user_ids) | ||||||
|  |         defer.returnValue(hosts) | ||||||
|  | 
 | ||||||
|     @cached(max_entries=500000, iterable=True) |     @cached(max_entries=500000, iterable=True) | ||||||
|     def get_users_in_room(self, room_id): |     def get_users_in_room(self, room_id): | ||||||
|         def f(txn): |         def f(txn): | ||||||
| 
 |             sql = ( | ||||||
|             rows = self._get_members_rows_txn( |                 "SELECT m.user_id FROM room_memberships as m" | ||||||
|                 txn, |                 " INNER JOIN current_state_events as c" | ||||||
|                 room_id=room_id, |                 " ON m.event_id = c.event_id " | ||||||
|                 membership=Membership.JOIN, |                 " AND m.room_id = c.room_id " | ||||||
|  |                 " AND m.user_id = c.state_key" | ||||||
|  |                 " WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ?" | ||||||
|             ) |             ) | ||||||
| 
 | 
 | ||||||
|             return [r["user_id"] for r in rows] |             txn.execute(sql, (room_id, Membership.JOIN,)) | ||||||
|  |             return [r[0] for r in txn] | ||||||
|         return self.runInteraction("get_users_in_room", f) |         return self.runInteraction("get_users_in_room", f) | ||||||
| 
 | 
 | ||||||
|     @cached() |     @cached() | ||||||
| @ -246,52 +259,27 @@ class RoomMemberStore(SQLBaseStore): | |||||||
| 
 | 
 | ||||||
|         return results |         return results | ||||||
| 
 | 
 | ||||||
|     def _get_members_rows_txn(self, txn, room_id, membership=None, user_id=None): |     @cachedInlineCallbacks(max_entries=500000, iterable=True) | ||||||
|         where_clause = "c.room_id = ?" |  | ||||||
|         where_values = [room_id] |  | ||||||
| 
 |  | ||||||
|         if membership: |  | ||||||
|             where_clause += " AND m.membership = ?" |  | ||||||
|             where_values.append(membership) |  | ||||||
| 
 |  | ||||||
|         if user_id: |  | ||||||
|             where_clause += " AND m.user_id = ?" |  | ||||||
|             where_values.append(user_id) |  | ||||||
| 
 |  | ||||||
|         sql = ( |  | ||||||
|             "SELECT m.* FROM room_memberships as m" |  | ||||||
|             " INNER JOIN current_state_events as c" |  | ||||||
|             " ON m.event_id = c.event_id " |  | ||||||
|             " AND m.room_id = c.room_id " |  | ||||||
|             " AND m.user_id = c.state_key" |  | ||||||
|             " WHERE c.type = 'm.room.member' AND %(where)s" |  | ||||||
|         ) % { |  | ||||||
|             "where": where_clause, |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         txn.execute(sql, where_values) |  | ||||||
|         rows = self.cursor_to_dict(txn) |  | ||||||
| 
 |  | ||||||
|         return rows |  | ||||||
| 
 |  | ||||||
|     @cached(max_entries=500000, iterable=True) |  | ||||||
|     def get_rooms_for_user(self, user_id): |     def get_rooms_for_user(self, user_id): | ||||||
|         return self.get_rooms_for_user_where_membership_is( |         """Returns a set of room_ids the user is currently joined to | ||||||
|  |         """ | ||||||
|  |         rooms = yield self.get_rooms_for_user_where_membership_is( | ||||||
|             user_id, membership_list=[Membership.JOIN], |             user_id, membership_list=[Membership.JOIN], | ||||||
|         ) |         ) | ||||||
|  |         defer.returnValue(frozenset(r.room_id for r in rooms)) | ||||||
| 
 | 
 | ||||||
|     @cachedInlineCallbacks(max_entries=500000, cache_context=True, iterable=True) |     @cachedInlineCallbacks(max_entries=500000, cache_context=True, iterable=True) | ||||||
|     def get_users_who_share_room_with_user(self, user_id, cache_context): |     def get_users_who_share_room_with_user(self, user_id, cache_context): | ||||||
|         """Returns the set of users who share a room with `user_id` |         """Returns the set of users who share a room with `user_id` | ||||||
|         """ |         """ | ||||||
|         rooms = yield self.get_rooms_for_user( |         room_ids = yield self.get_rooms_for_user( | ||||||
|             user_id, on_invalidate=cache_context.invalidate, |             user_id, on_invalidate=cache_context.invalidate, | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|         user_who_share_room = set() |         user_who_share_room = set() | ||||||
|         for room in rooms: |         for room_id in room_ids: | ||||||
|             user_ids = yield self.get_users_in_room( |             user_ids = yield self.get_users_in_room( | ||||||
|                 room.room_id, on_invalidate=cache_context.invalidate, |                 room_id, on_invalidate=cache_context.invalidate, | ||||||
|             ) |             ) | ||||||
|             user_who_share_room.update(user_ids) |             user_who_share_room.update(user_ids) | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -72,7 +72,7 @@ class SignatureStore(SQLBaseStore): | |||||||
|             " WHERE event_id = ?" |             " WHERE event_id = ?" | ||||||
|         ) |         ) | ||||||
|         txn.execute(query, (event_id, )) |         txn.execute(query, (event_id, )) | ||||||
|         return {k: v for k, v in txn.fetchall()} |         return {k: v for k, v in txn} | ||||||
| 
 | 
 | ||||||
|     def _store_event_reference_hashes_txn(self, txn, events): |     def _store_event_reference_hashes_txn(self, txn, events): | ||||||
|         """Store a hash for a PDU |         """Store a hash for a PDU | ||||||
|  | |||||||
| @ -14,7 +14,7 @@ | |||||||
| # limitations under the License. | # limitations under the License. | ||||||
| 
 | 
 | ||||||
| from ._base import SQLBaseStore | from ._base import SQLBaseStore | ||||||
| from synapse.util.caches.descriptors import cached, cachedList | from synapse.util.caches.descriptors import cached, cachedList, cachedInlineCallbacks | ||||||
| from synapse.util.caches import intern_string | from synapse.util.caches import intern_string | ||||||
| from synapse.storage.engines import PostgresEngine | from synapse.storage.engines import PostgresEngine | ||||||
| 
 | 
 | ||||||
| @ -69,6 +69,18 @@ class StateStore(SQLBaseStore): | |||||||
|             where_clause="type='m.room.member'", |             where_clause="type='m.room.member'", | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|  |     @cachedInlineCallbacks(max_entries=100000, iterable=True) | ||||||
|  |     def get_current_state_ids(self, room_id): | ||||||
|  |         rows = yield self._simple_select_list( | ||||||
|  |             table="current_state_events", | ||||||
|  |             keyvalues={"room_id": room_id}, | ||||||
|  |             retcols=["event_id", "type", "state_key"], | ||||||
|  |             desc="_calculate_state_delta", | ||||||
|  |         ) | ||||||
|  |         defer.returnValue({ | ||||||
|  |             (r["type"], r["state_key"]): r["event_id"] for r in rows | ||||||
|  |         }) | ||||||
|  | 
 | ||||||
|     @defer.inlineCallbacks |     @defer.inlineCallbacks | ||||||
|     def get_state_groups_ids(self, room_id, event_ids): |     def get_state_groups_ids(self, room_id, event_ids): | ||||||
|         if not event_ids: |         if not event_ids: | ||||||
| @ -78,7 +90,7 @@ class StateStore(SQLBaseStore): | |||||||
|             event_ids, |             event_ids, | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|         groups = set(event_to_groups.values()) |         groups = set(event_to_groups.itervalues()) | ||||||
|         group_to_state = yield self._get_state_for_groups(groups) |         group_to_state = yield self._get_state_for_groups(groups) | ||||||
| 
 | 
 | ||||||
|         defer.returnValue(group_to_state) |         defer.returnValue(group_to_state) | ||||||
| @ -96,17 +108,18 @@ class StateStore(SQLBaseStore): | |||||||
| 
 | 
 | ||||||
|         state_event_map = yield self.get_events( |         state_event_map = yield self.get_events( | ||||||
|             [ |             [ | ||||||
|                 ev_id for group_ids in group_to_ids.values() |                 ev_id for group_ids in group_to_ids.itervalues() | ||||||
|                 for ev_id in group_ids.values() |                 for ev_id in group_ids.itervalues() | ||||||
|             ], |             ], | ||||||
|             get_prev_content=False |             get_prev_content=False | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|         defer.returnValue({ |         defer.returnValue({ | ||||||
|             group: [ |             group: [ | ||||||
|                 state_event_map[v] for v in event_id_map.values() if v in state_event_map |                 state_event_map[v] for v in event_id_map.itervalues() | ||||||
|  |                 if v in state_event_map | ||||||
|             ] |             ] | ||||||
|             for group, event_id_map in group_to_ids.items() |             for group, event_id_map in group_to_ids.iteritems() | ||||||
|         }) |         }) | ||||||
| 
 | 
 | ||||||
|     def _have_persisted_state_group_txn(self, txn, state_group): |     def _have_persisted_state_group_txn(self, txn, state_group): | ||||||
| @ -124,6 +137,16 @@ class StateStore(SQLBaseStore): | |||||||
|                 continue |                 continue | ||||||
| 
 | 
 | ||||||
|             if context.current_state_ids is None: |             if context.current_state_ids is None: | ||||||
|  |                 # AFAIK, this can never happen | ||||||
|  |                 logger.error( | ||||||
|  |                     "Non-outlier event %s had current_state_ids==None", | ||||||
|  |                     event.event_id) | ||||||
|  |                 continue | ||||||
|  | 
 | ||||||
|  |             # if the event was rejected, just give it the same state as its | ||||||
|  |             # predecessor. | ||||||
|  |             if context.rejected: | ||||||
|  |                 state_groups[event.event_id] = context.prev_group | ||||||
|                 continue |                 continue | ||||||
| 
 | 
 | ||||||
|             state_groups[event.event_id] = context.state_group |             state_groups[event.event_id] = context.state_group | ||||||
| @ -168,7 +191,7 @@ class StateStore(SQLBaseStore): | |||||||
|                             "state_key": key[1], |                             "state_key": key[1], | ||||||
|                             "event_id": state_id, |                             "event_id": state_id, | ||||||
|                         } |                         } | ||||||
|                         for key, state_id in context.delta_ids.items() |                         for key, state_id in context.delta_ids.iteritems() | ||||||
|                     ], |                     ], | ||||||
|                 ) |                 ) | ||||||
|             else: |             else: | ||||||
| @ -183,7 +206,7 @@ class StateStore(SQLBaseStore): | |||||||
|                             "state_key": key[1], |                             "state_key": key[1], | ||||||
|                             "event_id": state_id, |                             "event_id": state_id, | ||||||
|                         } |                         } | ||||||
|                         for key, state_id in context.current_state_ids.items() |                         for key, state_id in context.current_state_ids.iteritems() | ||||||
|                     ], |                     ], | ||||||
|                 ) |                 ) | ||||||
| 
 | 
 | ||||||
| @ -195,7 +218,7 @@ class StateStore(SQLBaseStore): | |||||||
|                     "state_group": state_group_id, |                     "state_group": state_group_id, | ||||||
|                     "event_id": event_id, |                     "event_id": event_id, | ||||||
|                 } |                 } | ||||||
|                 for event_id, state_group_id in state_groups.items() |                 for event_id, state_group_id in state_groups.iteritems() | ||||||
|             ], |             ], | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
| @ -319,10 +342,10 @@ class StateStore(SQLBaseStore): | |||||||
|                     args.extend(where_args) |                     args.extend(where_args) | ||||||
| 
 | 
 | ||||||
|                     txn.execute(sql % (where_clause,), args) |                     txn.execute(sql % (where_clause,), args) | ||||||
|                     rows = self.cursor_to_dict(txn) |                     for row in txn: | ||||||
|                     for row in rows: |                         typ, state_key, event_id = row | ||||||
|                         key = (row["type"], row["state_key"]) |                         key = (typ, state_key) | ||||||
|                         results[group][key] = row["event_id"] |                         results[group][key] = event_id | ||||||
|         else: |         else: | ||||||
|             if types is not None: |             if types is not None: | ||||||
|                 where_clause = "AND (%s)" % ( |                 where_clause = "AND (%s)" % ( | ||||||
| @ -351,12 +374,11 @@ class StateStore(SQLBaseStore): | |||||||
|                         " WHERE state_group = ? %s" % (where_clause,), |                         " WHERE state_group = ? %s" % (where_clause,), | ||||||
|                         args |                         args | ||||||
|                     ) |                     ) | ||||||
|                     rows = txn.fetchall() |                     results[group].update( | ||||||
|                     results[group].update({ |                         ((typ, state_key), event_id) | ||||||
|                         (typ, state_key): event_id |                         for typ, state_key, event_id in txn | ||||||
|                         for typ, state_key, event_id in rows |  | ||||||
|                         if (typ, state_key) not in results[group] |                         if (typ, state_key) not in results[group] | ||||||
|                     }) |                     ) | ||||||
| 
 | 
 | ||||||
|                     # If the lengths match then we must have all the types, |                     # If the lengths match then we must have all the types, | ||||||
|                     # so no need to go walk further down the tree. |                     # so no need to go walk further down the tree. | ||||||
| @ -393,21 +415,21 @@ class StateStore(SQLBaseStore): | |||||||
|             event_ids, |             event_ids, | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|         groups = set(event_to_groups.values()) |         groups = set(event_to_groups.itervalues()) | ||||||
|         group_to_state = yield self._get_state_for_groups(groups, types) |         group_to_state = yield self._get_state_for_groups(groups, types) | ||||||
| 
 | 
 | ||||||
|         state_event_map = yield self.get_events( |         state_event_map = yield self.get_events( | ||||||
|             [ev_id for sd in group_to_state.values() for ev_id in sd.values()], |             [ev_id for sd in group_to_state.itervalues() for ev_id in sd.itervalues()], | ||||||
|             get_prev_content=False |             get_prev_content=False | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|         event_to_state = { |         event_to_state = { | ||||||
|             event_id: { |             event_id: { | ||||||
|                 k: state_event_map[v] |                 k: state_event_map[v] | ||||||
|                 for k, v in group_to_state[group].items() |                 for k, v in group_to_state[group].iteritems() | ||||||
|                 if v in state_event_map |                 if v in state_event_map | ||||||
|             } |             } | ||||||
|             for event_id, group in event_to_groups.items() |             for event_id, group in event_to_groups.iteritems() | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         defer.returnValue({event: event_to_state[event] for event in event_ids}) |         defer.returnValue({event: event_to_state[event] for event in event_ids}) | ||||||
| @ -430,12 +452,12 @@ class StateStore(SQLBaseStore): | |||||||
|             event_ids, |             event_ids, | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|         groups = set(event_to_groups.values()) |         groups = set(event_to_groups.itervalues()) | ||||||
|         group_to_state = yield self._get_state_for_groups(groups, types) |         group_to_state = yield self._get_state_for_groups(groups, types) | ||||||
| 
 | 
 | ||||||
|         event_to_state = { |         event_to_state = { | ||||||
|             event_id: group_to_state[group] |             event_id: group_to_state[group] | ||||||
|             for event_id, group in event_to_groups.items() |             for event_id, group in event_to_groups.iteritems() | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         defer.returnValue({event: event_to_state[event] for event in event_ids}) |         defer.returnValue({event: event_to_state[event] for event in event_ids}) | ||||||
| @ -474,7 +496,7 @@ class StateStore(SQLBaseStore): | |||||||
|         state_map = yield self.get_state_ids_for_events([event_id], types) |         state_map = yield self.get_state_ids_for_events([event_id], types) | ||||||
|         defer.returnValue(state_map[event_id]) |         defer.returnValue(state_map[event_id]) | ||||||
| 
 | 
 | ||||||
|     @cached(num_args=2, max_entries=10000) |     @cached(num_args=2, max_entries=100000) | ||||||
|     def _get_state_group_for_event(self, room_id, event_id): |     def _get_state_group_for_event(self, room_id, event_id): | ||||||
|         return self._simple_select_one_onecol( |         return self._simple_select_one_onecol( | ||||||
|             table="event_to_state_groups", |             table="event_to_state_groups", | ||||||
| @ -547,7 +569,7 @@ class StateStore(SQLBaseStore): | |||||||
|         got_all = not (missing_types or types is None) |         got_all = not (missing_types or types is None) | ||||||
| 
 | 
 | ||||||
|         return { |         return { | ||||||
|             k: v for k, v in state_dict_ids.items() |             k: v for k, v in state_dict_ids.iteritems() | ||||||
|             if include(k[0], k[1]) |             if include(k[0], k[1]) | ||||||
|         }, missing_types, got_all |         }, missing_types, got_all | ||||||
| 
 | 
 | ||||||
| @ -606,7 +628,7 @@ class StateStore(SQLBaseStore): | |||||||
| 
 | 
 | ||||||
|             # Now we want to update the cache with all the things we fetched |             # Now we want to update the cache with all the things we fetched | ||||||
|             # from the database. |             # from the database. | ||||||
|             for group, group_state_dict in group_to_state_dict.items(): |             for group, group_state_dict in group_to_state_dict.iteritems(): | ||||||
|                 if types: |                 if types: | ||||||
|                     # We delibrately put key -> None mappings into the cache to |                     # We delibrately put key -> None mappings into the cache to | ||||||
|                     # cache absence of the key, on the assumption that if we've |                     # cache absence of the key, on the assumption that if we've | ||||||
| @ -621,10 +643,10 @@ class StateStore(SQLBaseStore): | |||||||
|                 else: |                 else: | ||||||
|                     state_dict = results[group] |                     state_dict = results[group] | ||||||
| 
 | 
 | ||||||
|                 state_dict.update({ |                 state_dict.update( | ||||||
|                     (intern_string(k[0]), intern_string(k[1])): v |                     ((intern_string(k[0]), intern_string(k[1])), v) | ||||||
|                     for k, v in group_state_dict.items() |                     for k, v in group_state_dict.iteritems() | ||||||
|                 }) |                 ) | ||||||
| 
 | 
 | ||||||
|                 self._state_group_cache.update( |                 self._state_group_cache.update( | ||||||
|                     cache_seq_num, |                     cache_seq_num, | ||||||
| @ -635,10 +657,10 @@ class StateStore(SQLBaseStore): | |||||||
| 
 | 
 | ||||||
|         # Remove all the entries with None values. The None values were just |         # Remove all the entries with None values. The None values were just | ||||||
|         # used for bookkeeping in the cache. |         # used for bookkeeping in the cache. | ||||||
|         for group, state_dict in results.items(): |         for group, state_dict in results.iteritems(): | ||||||
|             results[group] = { |             results[group] = { | ||||||
|                 key: event_id |                 key: event_id | ||||||
|                 for key, event_id in state_dict.items() |                 for key, event_id in state_dict.iteritems() | ||||||
|                 if event_id |                 if event_id | ||||||
|             } |             } | ||||||
| 
 | 
 | ||||||
| @ -727,7 +749,7 @@ class StateStore(SQLBaseStore): | |||||||
|                         # of keys |                         # of keys | ||||||
| 
 | 
 | ||||||
|                         delta_state = { |                         delta_state = { | ||||||
|                             key: value for key, value in curr_state.items() |                             key: value for key, value in curr_state.iteritems() | ||||||
|                             if prev_state.get(key, None) != value |                             if prev_state.get(key, None) != value | ||||||
|                         } |                         } | ||||||
| 
 | 
 | ||||||
| @ -767,7 +789,7 @@ class StateStore(SQLBaseStore): | |||||||
|                                     "state_key": key[1], |                                     "state_key": key[1], | ||||||
|                                     "event_id": state_id, |                                     "event_id": state_id, | ||||||
|                                 } |                                 } | ||||||
|                                 for key, state_id in delta_state.items() |                                 for key, state_id in delta_state.iteritems() | ||||||
|                             ], |                             ], | ||||||
|                         ) |                         ) | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -829,3 +829,6 @@ class StreamStore(SQLBaseStore): | |||||||
|             updatevalues={"stream_id": stream_id}, |             updatevalues={"stream_id": stream_id}, | ||||||
|             desc="update_federation_out_pos", |             desc="update_federation_out_pos", | ||||||
|         ) |         ) | ||||||
|  | 
 | ||||||
|  |     def has_room_changed_since(self, room_id, stream_id): | ||||||
|  |         return self._events_stream_cache.has_entity_changed(room_id, stream_id) | ||||||
|  | |||||||
| @ -95,7 +95,7 @@ class TagsStore(SQLBaseStore): | |||||||
|             for stream_id, user_id, room_id in tag_ids: |             for stream_id, user_id, room_id in tag_ids: | ||||||
|                 txn.execute(sql, (user_id, room_id)) |                 txn.execute(sql, (user_id, room_id)) | ||||||
|                 tags = [] |                 tags = [] | ||||||
|                 for tag, content in txn.fetchall(): |                 for tag, content in txn: | ||||||
|                     tags.append(json.dumps(tag) + ":" + content) |                     tags.append(json.dumps(tag) + ":" + content) | ||||||
|                 tag_json = "{" + ",".join(tags) + "}" |                 tag_json = "{" + ",".join(tags) + "}" | ||||||
|                 results.append((stream_id, user_id, room_id, tag_json)) |                 results.append((stream_id, user_id, room_id, tag_json)) | ||||||
| @ -132,7 +132,7 @@ class TagsStore(SQLBaseStore): | |||||||
|                 " WHERE user_id = ? AND stream_id > ?" |                 " WHERE user_id = ? AND stream_id > ?" | ||||||
|             ) |             ) | ||||||
|             txn.execute(sql, (user_id, stream_id)) |             txn.execute(sql, (user_id, stream_id)) | ||||||
|             room_ids = [row[0] for row in txn.fetchall()] |             room_ids = [row[0] for row in txn] | ||||||
|             return room_ids |             return room_ids | ||||||
| 
 | 
 | ||||||
|         changed = self._account_data_stream_cache.has_entity_changed( |         changed = self._account_data_stream_cache.has_entity_changed( | ||||||
|  | |||||||
| @ -30,6 +30,17 @@ class IdGenerator(object): | |||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def _load_current_id(db_conn, table, column, step=1): | def _load_current_id(db_conn, table, column, step=1): | ||||||
|  |     """ | ||||||
|  | 
 | ||||||
|  |     Args: | ||||||
|  |         db_conn (object): | ||||||
|  |         table (str): | ||||||
|  |         column (str): | ||||||
|  |         step (int): | ||||||
|  | 
 | ||||||
|  |     Returns: | ||||||
|  |         int | ||||||
|  |     """ | ||||||
|     cur = db_conn.cursor() |     cur = db_conn.cursor() | ||||||
|     if step == 1: |     if step == 1: | ||||||
|         cur.execute("SELECT MAX(%s) FROM %s" % (column, table,)) |         cur.execute("SELECT MAX(%s) FROM %s" % (column, table,)) | ||||||
| @ -131,6 +142,9 @@ class StreamIdGenerator(object): | |||||||
|     def get_current_token(self): |     def get_current_token(self): | ||||||
|         """Returns the maximum stream id such that all stream ids less than or |         """Returns the maximum stream id such that all stream ids less than or | ||||||
|         equal to it have been successfully persisted. |         equal to it have been successfully persisted. | ||||||
|  | 
 | ||||||
|  |         Returns: | ||||||
|  |             int | ||||||
|         """ |         """ | ||||||
|         with self._lock: |         with self._lock: | ||||||
|             if self._unfinished_ids: |             if self._unfinished_ids: | ||||||
|  | |||||||
| @ -26,7 +26,7 @@ logger = logging.getLogger(__name__) | |||||||
| 
 | 
 | ||||||
| class DeferredTimedOutError(SynapseError): | class DeferredTimedOutError(SynapseError): | ||||||
|     def __init__(self): |     def __init__(self): | ||||||
|         super(SynapseError).__init__(504, "Timed out") |         super(SynapseError, self).__init__(504, "Timed out") | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def unwrapFirstError(failure): | def unwrapFirstError(failure): | ||||||
| @ -93,8 +93,10 @@ class Clock(object): | |||||||
|         ret_deferred = defer.Deferred() |         ret_deferred = defer.Deferred() | ||||||
| 
 | 
 | ||||||
|         def timed_out_fn(): |         def timed_out_fn(): | ||||||
|  |             e = DeferredTimedOutError() | ||||||
|  | 
 | ||||||
|             try: |             try: | ||||||
|                 ret_deferred.errback(DeferredTimedOutError()) |                 ret_deferred.errback(e) | ||||||
|             except: |             except: | ||||||
|                 pass |                 pass | ||||||
| 
 | 
 | ||||||
| @ -114,7 +116,7 @@ class Clock(object): | |||||||
| 
 | 
 | ||||||
|         ret_deferred.addBoth(cancel) |         ret_deferred.addBoth(cancel) | ||||||
| 
 | 
 | ||||||
|         def sucess(res): |         def success(res): | ||||||
|             try: |             try: | ||||||
|                 ret_deferred.callback(res) |                 ret_deferred.callback(res) | ||||||
|             except: |             except: | ||||||
| @ -128,7 +130,7 @@ class Clock(object): | |||||||
|             except: |             except: | ||||||
|                 pass |                 pass | ||||||
| 
 | 
 | ||||||
|         given_deferred.addCallbacks(callback=sucess, errback=err) |         given_deferred.addCallbacks(callback=success, errback=err) | ||||||
| 
 | 
 | ||||||
|         timer = self.call_later(time_out, timed_out_fn) |         timer = self.call_later(time_out, timed_out_fn) | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -15,12 +15,9 @@ | |||||||
| import logging | import logging | ||||||
| 
 | 
 | ||||||
| from synapse.util.async import ObservableDeferred | from synapse.util.async import ObservableDeferred | ||||||
| from synapse.util import unwrapFirstError | from synapse.util import unwrapFirstError, logcontext | ||||||
| from synapse.util.caches.lrucache import LruCache | from synapse.util.caches.lrucache import LruCache | ||||||
| from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry | from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry | ||||||
| from synapse.util.logcontext import ( |  | ||||||
|     PreserveLoggingContext, preserve_context_over_deferred, preserve_context_over_fn |  | ||||||
| ) |  | ||||||
| 
 | 
 | ||||||
| from . import DEBUG_CACHES, register_cache | from . import DEBUG_CACHES, register_cache | ||||||
| 
 | 
 | ||||||
| @ -189,7 +186,55 @@ class Cache(object): | |||||||
|         self.cache.clear() |         self.cache.clear() | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class CacheDescriptor(object): | class _CacheDescriptorBase(object): | ||||||
|  |     def __init__(self, orig, num_args, inlineCallbacks, cache_context=False): | ||||||
|  |         self.orig = orig | ||||||
|  | 
 | ||||||
|  |         if inlineCallbacks: | ||||||
|  |             self.function_to_call = defer.inlineCallbacks(orig) | ||||||
|  |         else: | ||||||
|  |             self.function_to_call = orig | ||||||
|  | 
 | ||||||
|  |         arg_spec = inspect.getargspec(orig) | ||||||
|  |         all_args = arg_spec.args | ||||||
|  | 
 | ||||||
|  |         if "cache_context" in all_args: | ||||||
|  |             if not cache_context: | ||||||
|  |                 raise ValueError( | ||||||
|  |                     "Cannot have a 'cache_context' arg without setting" | ||||||
|  |                     " cache_context=True" | ||||||
|  |                 ) | ||||||
|  |         elif cache_context: | ||||||
|  |             raise ValueError( | ||||||
|  |                 "Cannot have cache_context=True without having an arg" | ||||||
|  |                 " named `cache_context`" | ||||||
|  |             ) | ||||||
|  | 
 | ||||||
|  |         if num_args is None: | ||||||
|  |             num_args = len(all_args) - 1 | ||||||
|  |             if cache_context: | ||||||
|  |                 num_args -= 1 | ||||||
|  | 
 | ||||||
|  |         if len(all_args) < num_args + 1: | ||||||
|  |             raise Exception( | ||||||
|  |                 "Not enough explicit positional arguments to key off for %r: " | ||||||
|  |                 "got %i args, but wanted %i. (@cached cannot key off *args or " | ||||||
|  |                 "**kwargs)" | ||||||
|  |                 % (orig.__name__, len(all_args), num_args) | ||||||
|  |             ) | ||||||
|  | 
 | ||||||
|  |         self.num_args = num_args | ||||||
|  |         self.arg_names = all_args[1:num_args + 1] | ||||||
|  | 
 | ||||||
|  |         if "cache_context" in self.arg_names: | ||||||
|  |             raise Exception( | ||||||
|  |                 "cache_context arg cannot be included among the cache keys" | ||||||
|  |             ) | ||||||
|  | 
 | ||||||
|  |         self.add_cache_context = cache_context | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class CacheDescriptor(_CacheDescriptorBase): | ||||||
|     """ A method decorator that applies a memoizing cache around the function. |     """ A method decorator that applies a memoizing cache around the function. | ||||||
| 
 | 
 | ||||||
|     This caches deferreds, rather than the results themselves. Deferreds that |     This caches deferreds, rather than the results themselves. Deferreds that | ||||||
| @ -217,52 +262,24 @@ class CacheDescriptor(object): | |||||||
|             r2 = yield self.bar2(key, on_invalidate=cache_context.invalidate) |             r2 = yield self.bar2(key, on_invalidate=cache_context.invalidate) | ||||||
|             defer.returnValue(r1 + r2) |             defer.returnValue(r1 + r2) | ||||||
| 
 | 
 | ||||||
|  |     Args: | ||||||
|  |         num_args (int): number of positional arguments (excluding ``self`` and | ||||||
|  |             ``cache_context``) to use as cache keys. Defaults to all named | ||||||
|  |             args of the function. | ||||||
|     """ |     """ | ||||||
|     def __init__(self, orig, max_entries=1000, num_args=1, tree=False, |     def __init__(self, orig, max_entries=1000, num_args=None, tree=False, | ||||||
|                  inlineCallbacks=False, cache_context=False, iterable=False): |                  inlineCallbacks=False, cache_context=False, iterable=False): | ||||||
|  | 
 | ||||||
|  |         super(CacheDescriptor, self).__init__( | ||||||
|  |             orig, num_args=num_args, inlineCallbacks=inlineCallbacks, | ||||||
|  |             cache_context=cache_context) | ||||||
|  | 
 | ||||||
|         max_entries = int(max_entries * CACHE_SIZE_FACTOR) |         max_entries = int(max_entries * CACHE_SIZE_FACTOR) | ||||||
| 
 | 
 | ||||||
|         self.orig = orig |  | ||||||
| 
 |  | ||||||
|         if inlineCallbacks: |  | ||||||
|             self.function_to_call = defer.inlineCallbacks(orig) |  | ||||||
|         else: |  | ||||||
|             self.function_to_call = orig |  | ||||||
| 
 |  | ||||||
|         self.max_entries = max_entries |         self.max_entries = max_entries | ||||||
|         self.num_args = num_args |  | ||||||
|         self.tree = tree |         self.tree = tree | ||||||
| 
 |  | ||||||
|         self.iterable = iterable |         self.iterable = iterable | ||||||
| 
 | 
 | ||||||
|         all_args = inspect.getargspec(orig) |  | ||||||
|         self.arg_names = all_args.args[1:num_args + 1] |  | ||||||
| 
 |  | ||||||
|         if "cache_context" in all_args.args: |  | ||||||
|             if not cache_context: |  | ||||||
|                 raise ValueError( |  | ||||||
|                     "Cannot have a 'cache_context' arg without setting" |  | ||||||
|                     " cache_context=True" |  | ||||||
|                 ) |  | ||||||
|             try: |  | ||||||
|                 self.arg_names.remove("cache_context") |  | ||||||
|             except ValueError: |  | ||||||
|                 pass |  | ||||||
|         elif cache_context: |  | ||||||
|             raise ValueError( |  | ||||||
|                 "Cannot have cache_context=True without having an arg" |  | ||||||
|                 " named `cache_context`" |  | ||||||
|             ) |  | ||||||
| 
 |  | ||||||
|         self.add_cache_context = cache_context |  | ||||||
| 
 |  | ||||||
|         if len(self.arg_names) < self.num_args: |  | ||||||
|             raise Exception( |  | ||||||
|                 "Not enough explicit positional arguments to key off of for %r." |  | ||||||
|                 " (@cached cannot key off of *args or **kwargs)" |  | ||||||
|                 % (orig.__name__,) |  | ||||||
|             ) |  | ||||||
| 
 |  | ||||||
|     def __get__(self, obj, objtype=None): |     def __get__(self, obj, objtype=None): | ||||||
|         cache = Cache( |         cache = Cache( | ||||||
|             name=self.orig.__name__, |             name=self.orig.__name__, | ||||||
| @ -308,11 +325,9 @@ class CacheDescriptor(object): | |||||||
|                         defer.returnValue(cached_result) |                         defer.returnValue(cached_result) | ||||||
|                     observer.addCallback(check_result) |                     observer.addCallback(check_result) | ||||||
| 
 | 
 | ||||||
|                 return preserve_context_over_deferred(observer) |  | ||||||
|             except KeyError: |             except KeyError: | ||||||
|                 ret = defer.maybeDeferred( |                 ret = defer.maybeDeferred( | ||||||
|                     preserve_context_over_fn, |                     logcontext.preserve_fn(self.function_to_call), | ||||||
|                     self.function_to_call, |  | ||||||
|                     obj, *args, **kwargs |                     obj, *args, **kwargs | ||||||
|                 ) |                 ) | ||||||
| 
 | 
 | ||||||
| @ -322,10 +337,11 @@ class CacheDescriptor(object): | |||||||
| 
 | 
 | ||||||
|                 ret.addErrback(onErr) |                 ret.addErrback(onErr) | ||||||
| 
 | 
 | ||||||
|                 ret = ObservableDeferred(ret, consumeErrors=True) |                 result_d = ObservableDeferred(ret, consumeErrors=True) | ||||||
|                 cache.set(cache_key, ret, callback=invalidate_callback) |                 cache.set(cache_key, result_d, callback=invalidate_callback) | ||||||
|  |                 observer = result_d.observe() | ||||||
| 
 | 
 | ||||||
|                 return preserve_context_over_deferred(ret.observe()) |             return logcontext.make_deferred_yieldable(observer) | ||||||
| 
 | 
 | ||||||
|         wrapped.invalidate = cache.invalidate |         wrapped.invalidate = cache.invalidate | ||||||
|         wrapped.invalidate_all = cache.invalidate_all |         wrapped.invalidate_all = cache.invalidate_all | ||||||
| @ -338,48 +354,40 @@ class CacheDescriptor(object): | |||||||
|         return wrapped |         return wrapped | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class CacheListDescriptor(object): | class CacheListDescriptor(_CacheDescriptorBase): | ||||||
|     """Wraps an existing cache to support bulk fetching of keys. |     """Wraps an existing cache to support bulk fetching of keys. | ||||||
| 
 | 
 | ||||||
|     Given a list of keys it looks in the cache to find any hits, then passes |     Given a list of keys it looks in the cache to find any hits, then passes | ||||||
|     the list of missing keys to the wrapped fucntion. |     the list of missing keys to the wrapped function. | ||||||
|  | 
 | ||||||
|  |     Once wrapped, the function returns either a Deferred which resolves to | ||||||
|  |     the list of results, or (if all results were cached), just the list of | ||||||
|  |     results. | ||||||
|     """ |     """ | ||||||
| 
 | 
 | ||||||
|     def __init__(self, orig, cached_method_name, list_name, num_args=1, |     def __init__(self, orig, cached_method_name, list_name, num_args=None, | ||||||
|                  inlineCallbacks=False): |                  inlineCallbacks=False): | ||||||
|         """ |         """ | ||||||
|         Args: |         Args: | ||||||
|             orig (function) |             orig (function) | ||||||
|             method_name (str); The name of the chached method. |             cached_method_name (str): The name of the chached method. | ||||||
|             list_name (str): Name of the argument which is the bulk lookup list |             list_name (str): Name of the argument which is the bulk lookup list | ||||||
|             num_args (int) |             num_args (int): number of positional arguments (excluding ``self``, | ||||||
|  |                 but including list_name) to use as cache keys. Defaults to all | ||||||
|  |                 named args of the function. | ||||||
|             inlineCallbacks (bool): Whether orig is a generator that should |             inlineCallbacks (bool): Whether orig is a generator that should | ||||||
|                 be wrapped by defer.inlineCallbacks |                 be wrapped by defer.inlineCallbacks | ||||||
|         """ |         """ | ||||||
|         self.orig = orig |         super(CacheListDescriptor, self).__init__( | ||||||
|  |             orig, num_args=num_args, inlineCallbacks=inlineCallbacks) | ||||||
| 
 | 
 | ||||||
|         if inlineCallbacks: |  | ||||||
|             self.function_to_call = defer.inlineCallbacks(orig) |  | ||||||
|         else: |  | ||||||
|             self.function_to_call = orig |  | ||||||
| 
 |  | ||||||
|         self.num_args = num_args |  | ||||||
|         self.list_name = list_name |         self.list_name = list_name | ||||||
| 
 | 
 | ||||||
|         self.arg_names = inspect.getargspec(orig).args[1:num_args + 1] |  | ||||||
|         self.list_pos = self.arg_names.index(self.list_name) |         self.list_pos = self.arg_names.index(self.list_name) | ||||||
| 
 |  | ||||||
|         self.cached_method_name = cached_method_name |         self.cached_method_name = cached_method_name | ||||||
| 
 | 
 | ||||||
|         self.sentinel = object() |         self.sentinel = object() | ||||||
| 
 | 
 | ||||||
|         if len(self.arg_names) < self.num_args: |  | ||||||
|             raise Exception( |  | ||||||
|                 "Not enough explicit positional arguments to key off of for %r." |  | ||||||
|                 " (@cached cannot key off of *args or **kwars)" |  | ||||||
|                 % (orig.__name__,) |  | ||||||
|             ) |  | ||||||
| 
 |  | ||||||
|         if self.list_name not in self.arg_names: |         if self.list_name not in self.arg_names: | ||||||
|             raise Exception( |             raise Exception( | ||||||
|                 "Couldn't see arguments %r for %r." |                 "Couldn't see arguments %r for %r." | ||||||
| @ -425,8 +433,7 @@ class CacheListDescriptor(object): | |||||||
|                 args_to_call[self.list_name] = missing |                 args_to_call[self.list_name] = missing | ||||||
| 
 | 
 | ||||||
|                 ret_d = defer.maybeDeferred( |                 ret_d = defer.maybeDeferred( | ||||||
|                     preserve_context_over_fn, |                     logcontext.preserve_fn(self.function_to_call), | ||||||
|                     self.function_to_call, |  | ||||||
|                     **args_to_call |                     **args_to_call | ||||||
|                 ) |                 ) | ||||||
| 
 | 
 | ||||||
| @ -435,7 +442,6 @@ class CacheListDescriptor(object): | |||||||
|                 # We need to create deferreds for each arg in the list so that |                 # We need to create deferreds for each arg in the list so that | ||||||
|                 # we can insert the new deferred into the cache. |                 # we can insert the new deferred into the cache. | ||||||
|                 for arg in missing: |                 for arg in missing: | ||||||
|                     with PreserveLoggingContext(): |  | ||||||
|                     observer = ret_d.observe() |                     observer = ret_d.observe() | ||||||
|                     observer.addCallback(lambda r, arg: r.get(arg, None), arg) |                     observer.addCallback(lambda r, arg: r.get(arg, None), arg) | ||||||
| 
 | 
 | ||||||
| @ -463,7 +469,7 @@ class CacheListDescriptor(object): | |||||||
|                     results.update(res) |                     results.update(res) | ||||||
|                     return results |                     return results | ||||||
| 
 | 
 | ||||||
|                 return preserve_context_over_deferred(defer.gatherResults( |                 return logcontext.make_deferred_yieldable(defer.gatherResults( | ||||||
|                     cached_defers.values(), |                     cached_defers.values(), | ||||||
|                     consumeErrors=True, |                     consumeErrors=True, | ||||||
|                 ).addCallback(update_results_dict).addErrback( |                 ).addCallback(update_results_dict).addErrback( | ||||||
| @ -487,7 +493,7 @@ class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))): | |||||||
|         self.cache.invalidate(self.key) |         self.cache.invalidate(self.key) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def cached(max_entries=1000, num_args=1, tree=False, cache_context=False, | def cached(max_entries=1000, num_args=None, tree=False, cache_context=False, | ||||||
|            iterable=False): |            iterable=False): | ||||||
|     return lambda orig: CacheDescriptor( |     return lambda orig: CacheDescriptor( | ||||||
|         orig, |         orig, | ||||||
| @ -499,8 +505,8 @@ def cached(max_entries=1000, num_args=1, tree=False, cache_context=False, | |||||||
|     ) |     ) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_context=False, | def cachedInlineCallbacks(max_entries=1000, num_args=None, tree=False, | ||||||
|                           iterable=False): |                           cache_context=False, iterable=False): | ||||||
|     return lambda orig: CacheDescriptor( |     return lambda orig: CacheDescriptor( | ||||||
|         orig, |         orig, | ||||||
|         max_entries=max_entries, |         max_entries=max_entries, | ||||||
| @ -512,7 +518,7 @@ def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_contex | |||||||
|     ) |     ) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def cachedList(cached_method_name, list_name, num_args=1, inlineCallbacks=False): | def cachedList(cached_method_name, list_name, num_args=None, inlineCallbacks=False): | ||||||
|     """Creates a descriptor that wraps a function in a `CacheListDescriptor`. |     """Creates a descriptor that wraps a function in a `CacheListDescriptor`. | ||||||
| 
 | 
 | ||||||
|     Used to do batch lookups for an already created cache. A single argument |     Used to do batch lookups for an already created cache. A single argument | ||||||
| @ -525,7 +531,8 @@ def cachedList(cached_method_name, list_name, num_args=1, inlineCallbacks=False) | |||||||
|         cache (Cache): The underlying cache to use. |         cache (Cache): The underlying cache to use. | ||||||
|         list_name (str): The name of the argument that is the list to use to |         list_name (str): The name of the argument that is the list to use to | ||||||
|             do batch lookups in the cache. |             do batch lookups in the cache. | ||||||
|         num_args (int): Number of arguments to use as the key in the cache. |         num_args (int): Number of arguments to use as the key in the cache | ||||||
|  |             (including list_name). Defaults to all named parameters. | ||||||
|         inlineCallbacks (bool): Should the function be wrapped in an |         inlineCallbacks (bool): Should the function be wrapped in an | ||||||
|             `defer.inlineCallbacks`? |             `defer.inlineCallbacks`? | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -50,7 +50,7 @@ class StreamChangeCache(object): | |||||||
|     def has_entity_changed(self, entity, stream_pos): |     def has_entity_changed(self, entity, stream_pos): | ||||||
|         """Returns True if the entity may have been updated since stream_pos |         """Returns True if the entity may have been updated since stream_pos | ||||||
|         """ |         """ | ||||||
|         assert type(stream_pos) is int |         assert type(stream_pos) is int or type(stream_pos) is long | ||||||
| 
 | 
 | ||||||
|         if stream_pos < self._earliest_known_stream_pos: |         if stream_pos < self._earliest_known_stream_pos: | ||||||
|             self.metrics.inc_misses() |             self.metrics.inc_misses() | ||||||
|  | |||||||
| @ -12,6 +12,16 @@ | |||||||
| # See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||||
| # limitations under the License. | # limitations under the License. | ||||||
| 
 | 
 | ||||||
|  | """ Thread-local-alike tracking of log contexts within synapse | ||||||
|  | 
 | ||||||
|  | This module provides objects and utilities for tracking contexts through | ||||||
|  | synapse code, so that log lines can include a request identifier, and so that | ||||||
|  | CPU and database activity can be accounted for against the request that caused | ||||||
|  | them. | ||||||
|  | 
 | ||||||
|  | See doc/log_contexts.rst for details on how this works. | ||||||
|  | """ | ||||||
|  | 
 | ||||||
| from twisted.internet import defer | from twisted.internet import defer | ||||||
| 
 | 
 | ||||||
| import threading | import threading | ||||||
| @ -300,6 +310,10 @@ def preserve_context_over_fn(fn, *args, **kwargs): | |||||||
| def preserve_context_over_deferred(deferred, context=None): | def preserve_context_over_deferred(deferred, context=None): | ||||||
|     """Given a deferred wrap it such that any callbacks added later to it will |     """Given a deferred wrap it such that any callbacks added later to it will | ||||||
|     be invoked with the current context. |     be invoked with the current context. | ||||||
|  | 
 | ||||||
|  |     Deprecated: this almost certainly doesn't do want you want, ie make | ||||||
|  |     the deferred follow the synapse logcontext rules: try | ||||||
|  |     ``make_deferred_yieldable`` instead. | ||||||
|     """ |     """ | ||||||
|     if context is None: |     if context is None: | ||||||
|         context = LoggingContext.current_context() |         context = LoggingContext.current_context() | ||||||
| @ -309,24 +323,65 @@ def preserve_context_over_deferred(deferred, context=None): | |||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def preserve_fn(f): | def preserve_fn(f): | ||||||
|     """Ensures that function is called with correct context and that context is |     """Wraps a function, to ensure that the current context is restored after | ||||||
|     restored after return. Useful for wrapping functions that return a deferred |     return from the function, and that the sentinel context is set once the | ||||||
|     which you don't yield on. |     deferred returned by the funtion completes. | ||||||
|  | 
 | ||||||
|  |     Useful for wrapping functions that return a deferred which you don't yield | ||||||
|  |     on. | ||||||
|     """ |     """ | ||||||
|  |     def reset_context(result): | ||||||
|  |         LoggingContext.set_current_context(LoggingContext.sentinel) | ||||||
|  |         return result | ||||||
|  | 
 | ||||||
|  |     # XXX: why is this here rather than inside g? surely we want to preserve | ||||||
|  |     # the context from the time the function was called, not when it was | ||||||
|  |     # wrapped? | ||||||
|     current = LoggingContext.current_context() |     current = LoggingContext.current_context() | ||||||
| 
 | 
 | ||||||
|     def g(*args, **kwargs): |     def g(*args, **kwargs): | ||||||
|         with PreserveLoggingContext(current): |  | ||||||
|         res = f(*args, **kwargs) |         res = f(*args, **kwargs) | ||||||
|             if isinstance(res, defer.Deferred): |         if isinstance(res, defer.Deferred) and not res.called: | ||||||
|                 return preserve_context_over_deferred( |             # The function will have reset the context before returning, so | ||||||
|                     res, context=LoggingContext.sentinel |             # we need to restore it now. | ||||||
|                 ) |             LoggingContext.set_current_context(current) | ||||||
|             else: | 
 | ||||||
|  |             # The original context will be restored when the deferred | ||||||
|  |             # completes, but there is nothing waiting for it, so it will | ||||||
|  |             # get leaked into the reactor or some other function which | ||||||
|  |             # wasn't expecting it. We therefore need to reset the context | ||||||
|  |             # here. | ||||||
|  |             # | ||||||
|  |             # (If this feels asymmetric, consider it this way: we are | ||||||
|  |             # effectively forking a new thread of execution. We are | ||||||
|  |             # probably currently within a ``with LoggingContext()`` block, | ||||||
|  |             # which is supposed to have a single entry and exit point. But | ||||||
|  |             # by spawning off another deferred, we are effectively | ||||||
|  |             # adding a new exit point.) | ||||||
|  |             res.addBoth(reset_context) | ||||||
|         return res |         return res | ||||||
|     return g |     return g | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @defer.inlineCallbacks | ||||||
|  | def make_deferred_yieldable(deferred): | ||||||
|  |     """Given a deferred, make it follow the Synapse logcontext rules: | ||||||
|  | 
 | ||||||
|  |     If the deferred has completed (or is not actually a Deferred), essentially | ||||||
|  |     does nothing (just returns another completed deferred with the | ||||||
|  |     result/failure). | ||||||
|  | 
 | ||||||
|  |     If the deferred has not yet completed, resets the logcontext before | ||||||
|  |     returning a deferred. Then, when the deferred completes, restores the | ||||||
|  |     current logcontext before running callbacks/errbacks. | ||||||
|  | 
 | ||||||
|  |     (This is more-or-less the opposite operation to preserve_fn.) | ||||||
|  |     """ | ||||||
|  |     with PreserveLoggingContext(): | ||||||
|  |         r = yield deferred | ||||||
|  |     defer.returnValue(r) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| # modules to ignore in `logcontext_tracer` | # modules to ignore in `logcontext_tracer` | ||||||
| _to_ignore = [ | _to_ignore = [ | ||||||
|     "synapse.util.logcontext", |     "synapse.util.logcontext", | ||||||
|  | |||||||
							
								
								
									
										40
									
								
								synapse/util/msisdn.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										40
									
								
								synapse/util/msisdn.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,40 @@ | |||||||
|  | # -*- coding: utf-8 -*- | ||||||
|  | # Copyright 2017 Vector Creations Ltd | ||||||
|  | # | ||||||
|  | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  | # you may not use this file except in compliance with the License. | ||||||
|  | # You may obtain a copy of the License at | ||||||
|  | # | ||||||
|  | #     http://www.apache.org/licenses/LICENSE-2.0 | ||||||
|  | # | ||||||
|  | # Unless required by applicable law or agreed to in writing, software | ||||||
|  | # distributed under the License is distributed on an "AS IS" BASIS, | ||||||
|  | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
|  | # See the License for the specific language governing permissions and | ||||||
|  | # limitations under the License. | ||||||
|  | 
 | ||||||
|  | import phonenumbers | ||||||
|  | from synapse.api.errors import SynapseError | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def phone_number_to_msisdn(country, number): | ||||||
|  |     """ | ||||||
|  |     Takes an ISO-3166-1 2 letter country code and phone number and | ||||||
|  |     returns an msisdn representing the canonical version of that | ||||||
|  |     phone number. | ||||||
|  |     Args: | ||||||
|  |         country (str): ISO-3166-1 2 letter country code | ||||||
|  |         number (str): Phone number in a national or international format | ||||||
|  | 
 | ||||||
|  |     Returns: | ||||||
|  |         (str) The canonical form of the phone number, as an msisdn | ||||||
|  |     Raises: | ||||||
|  |             SynapseError if the number could not be parsed. | ||||||
|  |     """ | ||||||
|  |     try: | ||||||
|  |         phoneNumber = phonenumbers.parse(number, country) | ||||||
|  |     except phonenumbers.NumberParseException: | ||||||
|  |         raise SynapseError(400, "Unable to parse phone number") | ||||||
|  |     return phonenumbers.format_number( | ||||||
|  |         phoneNumber, phonenumbers.PhoneNumberFormat.E164 | ||||||
|  |     )[1:] | ||||||
| @ -12,7 +12,7 @@ | |||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
| # See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||||
| # limitations under the License. | # limitations under the License. | ||||||
| 
 | import synapse.util.logcontext | ||||||
| from twisted.internet import defer | from twisted.internet import defer | ||||||
| 
 | 
 | ||||||
| from synapse.api.errors import CodeMessageException | from synapse.api.errors import CodeMessageException | ||||||
| @ -35,7 +35,8 @@ class NotRetryingDestination(Exception): | |||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @defer.inlineCallbacks | @defer.inlineCallbacks | ||||||
| def get_retry_limiter(destination, clock, store, **kwargs): | def get_retry_limiter(destination, clock, store, ignore_backoff=False, | ||||||
|  |                       **kwargs): | ||||||
|     """For a given destination check if we have previously failed to |     """For a given destination check if we have previously failed to | ||||||
|     send a request there and are waiting before retrying the destination. |     send a request there and are waiting before retrying the destination. | ||||||
|     If we are not ready to retry the destination, this will raise a |     If we are not ready to retry the destination, this will raise a | ||||||
| @ -43,6 +44,14 @@ def get_retry_limiter(destination, clock, store, **kwargs): | |||||||
|     that will mark the destination as down if an exception is thrown (excluding |     that will mark the destination as down if an exception is thrown (excluding | ||||||
|     CodeMessageException with code < 500) |     CodeMessageException with code < 500) | ||||||
| 
 | 
 | ||||||
|  |     Args: | ||||||
|  |         destination (str): name of homeserver | ||||||
|  |         clock (synapse.util.clock): timing source | ||||||
|  |         store (synapse.storage.transactions.TransactionStore): datastore | ||||||
|  |         ignore_backoff (bool): true to ignore the historical backoff data and | ||||||
|  |             try the request anyway. We will still update the next | ||||||
|  |             retry_interval on success/failure. | ||||||
|  | 
 | ||||||
|     Example usage: |     Example usage: | ||||||
| 
 | 
 | ||||||
|         try: |         try: | ||||||
| @ -66,7 +75,7 @@ def get_retry_limiter(destination, clock, store, **kwargs): | |||||||
| 
 | 
 | ||||||
|         now = int(clock.time_msec()) |         now = int(clock.time_msec()) | ||||||
| 
 | 
 | ||||||
|         if retry_last_ts + retry_interval > now: |         if not ignore_backoff and retry_last_ts + retry_interval > now: | ||||||
|             raise NotRetryingDestination( |             raise NotRetryingDestination( | ||||||
|                 retry_last_ts=retry_last_ts, |                 retry_last_ts=retry_last_ts, | ||||||
|                 retry_interval=retry_interval, |                 retry_interval=retry_interval, | ||||||
| @ -124,7 +133,13 @@ class RetryDestinationLimiter(object): | |||||||
| 
 | 
 | ||||||
|     def __exit__(self, exc_type, exc_val, exc_tb): |     def __exit__(self, exc_type, exc_val, exc_tb): | ||||||
|         valid_err_code = False |         valid_err_code = False | ||||||
|         if exc_type is not None and issubclass(exc_type, CodeMessageException): |         if exc_type is None: | ||||||
|  |             valid_err_code = True | ||||||
|  |         elif not issubclass(exc_type, Exception): | ||||||
|  |             # avoid treating exceptions which don't derive from Exception as | ||||||
|  |             # failures; this is mostly so as not to catch defer._DefGen. | ||||||
|  |             valid_err_code = True | ||||||
|  |         elif issubclass(exc_type, CodeMessageException): | ||||||
|             # Some error codes are perfectly fine for some APIs, whereas other |             # Some error codes are perfectly fine for some APIs, whereas other | ||||||
|             # APIs may expect to never received e.g. a 404. It's important to |             # APIs may expect to never received e.g. a 404. It's important to | ||||||
|             # handle 404 as some remote servers will return a 404 when the HS |             # handle 404 as some remote servers will return a 404 when the HS | ||||||
| @ -142,11 +157,13 @@ class RetryDestinationLimiter(object): | |||||||
|             else: |             else: | ||||||
|                 valid_err_code = False |                 valid_err_code = False | ||||||
| 
 | 
 | ||||||
|         if exc_type is None or valid_err_code: |         if valid_err_code: | ||||||
|             # We connected successfully. |             # We connected successfully. | ||||||
|             if not self.retry_interval: |             if not self.retry_interval: | ||||||
|                 return |                 return | ||||||
| 
 | 
 | ||||||
|  |             logger.debug("Connection to %s was successful; clearing backoff", | ||||||
|  |                          self.destination) | ||||||
|             retry_last_ts = 0 |             retry_last_ts = 0 | ||||||
|             self.retry_interval = 0 |             self.retry_interval = 0 | ||||||
|         else: |         else: | ||||||
| @ -160,6 +177,10 @@ class RetryDestinationLimiter(object): | |||||||
|             else: |             else: | ||||||
|                 self.retry_interval = self.min_retry_interval |                 self.retry_interval = self.min_retry_interval | ||||||
| 
 | 
 | ||||||
|  |             logger.debug( | ||||||
|  |                 "Connection to %s was unsuccessful (%s(%s)); backoff now %i", | ||||||
|  |                 self.destination, exc_type, exc_val, self.retry_interval | ||||||
|  |             ) | ||||||
|             retry_last_ts = int(self.clock.time_msec()) |             retry_last_ts = int(self.clock.time_msec()) | ||||||
| 
 | 
 | ||||||
|         @defer.inlineCallbacks |         @defer.inlineCallbacks | ||||||
| @ -173,4 +194,5 @@ class RetryDestinationLimiter(object): | |||||||
|                     "Failed to store set_destination_retry_timings", |                     "Failed to store set_destination_retry_timings", | ||||||
|                 ) |                 ) | ||||||
| 
 | 
 | ||||||
|         store_retry_timings() |         # we deliberately do this in the background. | ||||||
|  |         synapse.util.logcontext.preserve_fn(store_retry_timings)() | ||||||
|  | |||||||
| @ -134,6 +134,13 @@ def filter_events_for_clients(store, user_tuples, events, event_id_to_state): | |||||||
|             if prev_membership not in MEMBERSHIP_PRIORITY: |             if prev_membership not in MEMBERSHIP_PRIORITY: | ||||||
|                 prev_membership = "leave" |                 prev_membership = "leave" | ||||||
| 
 | 
 | ||||||
|  |             # Always allow the user to see their own leave events, otherwise | ||||||
|  |             # they won't see the room disappear if they reject the invite | ||||||
|  |             if membership == "leave" and ( | ||||||
|  |                 prev_membership == "join" or prev_membership == "invite" | ||||||
|  |             ): | ||||||
|  |                 return True | ||||||
|  | 
 | ||||||
|             new_priority = MEMBERSHIP_PRIORITY.index(membership) |             new_priority = MEMBERSHIP_PRIORITY.index(membership) | ||||||
|             old_priority = MEMBERSHIP_PRIORITY.index(prev_membership) |             old_priority = MEMBERSHIP_PRIORITY.index(prev_membership) | ||||||
|             if old_priority < new_priority: |             if old_priority < new_priority: | ||||||
|  | |||||||
| @ -23,6 +23,9 @@ from tests.utils import ( | |||||||
| 
 | 
 | ||||||
| from synapse.api.filtering import Filter | from synapse.api.filtering import Filter | ||||||
| from synapse.events import FrozenEvent | from synapse.events import FrozenEvent | ||||||
|  | from synapse.api.errors import SynapseError | ||||||
|  | 
 | ||||||
|  | import jsonschema | ||||||
| 
 | 
 | ||||||
| user_localpart = "test_user" | user_localpart = "test_user" | ||||||
| 
 | 
 | ||||||
| @ -54,6 +57,70 @@ class FilteringTestCase(unittest.TestCase): | |||||||
| 
 | 
 | ||||||
|         self.datastore = hs.get_datastore() |         self.datastore = hs.get_datastore() | ||||||
| 
 | 
 | ||||||
|  |     def test_errors_on_invalid_filters(self): | ||||||
|  |         invalid_filters = [ | ||||||
|  |             {"boom": {}}, | ||||||
|  |             {"account_data": "Hello World"}, | ||||||
|  |             {"event_fields": ["\\foo"]}, | ||||||
|  |             {"room": {"timeline": {"limit": 0}, "state": {"not_bars": ["*"]}}}, | ||||||
|  |             {"event_format": "other"}, | ||||||
|  |             {"room": {"not_rooms": ["#foo:pik-test"]}}, | ||||||
|  |             {"presence": {"senders": ["@bar;pik.test.com"]}} | ||||||
|  |         ] | ||||||
|  |         for filter in invalid_filters: | ||||||
|  |             with self.assertRaises(SynapseError) as check_filter_error: | ||||||
|  |                 self.filtering.check_valid_filter(filter) | ||||||
|  |                 self.assertIsInstance(check_filter_error.exception, SynapseError) | ||||||
|  | 
 | ||||||
|  |     def test_valid_filters(self): | ||||||
|  |         valid_filters = [ | ||||||
|  |             { | ||||||
|  |                 "room": { | ||||||
|  |                     "timeline": {"limit": 20}, | ||||||
|  |                     "state": {"not_types": ["m.room.member"]}, | ||||||
|  |                     "ephemeral": {"limit": 0, "not_types": ["*"]}, | ||||||
|  |                     "include_leave": False, | ||||||
|  |                     "rooms": ["!dee:pik-test"], | ||||||
|  |                     "not_rooms": ["!gee:pik-test"], | ||||||
|  |                     "account_data": {"limit": 0, "types": ["*"]} | ||||||
|  |                 } | ||||||
|  |             }, | ||||||
|  |             { | ||||||
|  |                 "room": { | ||||||
|  |                     "state": { | ||||||
|  |                         "types": ["m.room.*"], | ||||||
|  |                         "not_rooms": ["!726s6s6q:example.com"] | ||||||
|  |                     }, | ||||||
|  |                     "timeline": { | ||||||
|  |                         "limit": 10, | ||||||
|  |                         "types": ["m.room.message"], | ||||||
|  |                         "not_rooms": ["!726s6s6q:example.com"], | ||||||
|  |                         "not_senders": ["@spam:example.com"] | ||||||
|  |                     }, | ||||||
|  |                     "ephemeral": { | ||||||
|  |                         "types": ["m.receipt", "m.typing"], | ||||||
|  |                         "not_rooms": ["!726s6s6q:example.com"], | ||||||
|  |                         "not_senders": ["@spam:example.com"] | ||||||
|  |                     } | ||||||
|  |                 }, | ||||||
|  |                 "presence": { | ||||||
|  |                     "types": ["m.presence"], | ||||||
|  |                     "not_senders": ["@alice:example.com"] | ||||||
|  |                 }, | ||||||
|  |                 "event_format": "client", | ||||||
|  |                 "event_fields": ["type", "content", "sender"] | ||||||
|  |             } | ||||||
|  |         ] | ||||||
|  |         for filter in valid_filters: | ||||||
|  |             try: | ||||||
|  |                 self.filtering.check_valid_filter(filter) | ||||||
|  |             except jsonschema.ValidationError as e: | ||||||
|  |                 self.fail(e) | ||||||
|  | 
 | ||||||
|  |     def test_limits_are_applied(self): | ||||||
|  |         # TODO | ||||||
|  |         pass | ||||||
|  | 
 | ||||||
|     def test_definition_types_works_with_literals(self): |     def test_definition_types_works_with_literals(self): | ||||||
|         definition = { |         definition = { | ||||||
|             "types": ["m.room.message", "org.matrix.foo.bar"] |             "types": ["m.room.message", "org.matrix.foo.bar"] | ||||||
|  | |||||||
| @ -93,6 +93,7 @@ class DirectoryTestCase(unittest.TestCase): | |||||||
|                 "room_alias": "#another:remote", |                 "room_alias": "#another:remote", | ||||||
|             }, |             }, | ||||||
|             retry_on_dns_fail=False, |             retry_on_dns_fail=False, | ||||||
|  |             ignore_backoff=True, | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|     @defer.inlineCallbacks |     @defer.inlineCallbacks | ||||||
|  | |||||||
| @ -324,7 +324,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): | |||||||
|         state = UserPresenceState.default(user_id) |         state = UserPresenceState.default(user_id) | ||||||
|         state = state.copy_and_replace( |         state = state.copy_and_replace( | ||||||
|             state=PresenceState.ONLINE, |             state=PresenceState.ONLINE, | ||||||
|             last_active_ts=now, |             last_active_ts=0, | ||||||
|             last_user_sync_ts=now - SYNC_ONLINE_TIMEOUT - 1, |             last_user_sync_ts=now - SYNC_ONLINE_TIMEOUT - 1, | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -119,7 +119,8 @@ class ProfileTestCase(unittest.TestCase): | |||||||
|         self.mock_federation.make_query.assert_called_with( |         self.mock_federation.make_query.assert_called_with( | ||||||
|             destination="remote", |             destination="remote", | ||||||
|             query_type="profile", |             query_type="profile", | ||||||
|             args={"user_id": "@alice:remote", "field": "displayname"} |             args={"user_id": "@alice:remote", "field": "displayname"}, | ||||||
|  |             ignore_backoff=True, | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|     @defer.inlineCallbacks |     @defer.inlineCallbacks | ||||||
|  | |||||||
| @ -192,6 +192,7 @@ class TypingNotificationsTestCase(unittest.TestCase): | |||||||
|                 ), |                 ), | ||||||
|                 json_data_callback=ANY, |                 json_data_callback=ANY, | ||||||
|                 long_retries=True, |                 long_retries=True, | ||||||
|  |                 backoff_on_404=True, | ||||||
|             ), |             ), | ||||||
|             defer.succeed((200, "OK")) |             defer.succeed((200, "OK")) | ||||||
|         ) |         ) | ||||||
| @ -263,6 +264,7 @@ class TypingNotificationsTestCase(unittest.TestCase): | |||||||
|                 ), |                 ), | ||||||
|                 json_data_callback=ANY, |                 json_data_callback=ANY, | ||||||
|                 long_retries=True, |                 long_retries=True, | ||||||
|  |                 backoff_on_404=True, | ||||||
|             ), |             ), | ||||||
|             defer.succeed((200, "OK")) |             defer.succeed((200, "OK")) | ||||||
|         ) |         ) | ||||||
|  | |||||||
| @ -68,7 +68,7 @@ class ReplicationResourceCase(unittest.TestCase): | |||||||
|         code, body = yield get |         code, body = yield get | ||||||
|         self.assertEquals(code, 200) |         self.assertEquals(code, 200) | ||||||
|         self.assertEquals(body["events"]["field_names"], [ |         self.assertEquals(body["events"]["field_names"], [ | ||||||
|             "position", "internal", "json", "state_group" |             "position", "event_id", "room_id", "type", "state_key", | ||||||
|         ]) |         ]) | ||||||
| 
 | 
 | ||||||
|     @defer.inlineCallbacks |     @defer.inlineCallbacks | ||||||
|  | |||||||
| @ -33,8 +33,8 @@ PATH_PREFIX = "/_matrix/client/v2_alpha" | |||||||
| class FilterTestCase(unittest.TestCase): | class FilterTestCase(unittest.TestCase): | ||||||
| 
 | 
 | ||||||
|     USER_ID = "@apple:test" |     USER_ID = "@apple:test" | ||||||
|     EXAMPLE_FILTER = {"type": ["m.*"]} |     EXAMPLE_FILTER = {"room": {"timeline": {"types": ["m.room.message"]}}} | ||||||
|     EXAMPLE_FILTER_JSON = '{"type": ["m.*"]}' |     EXAMPLE_FILTER_JSON = '{"room": {"timeline": {"types": ["m.room.message"]}}}' | ||||||
|     TO_REGISTER = [filter] |     TO_REGISTER = [filter] | ||||||
| 
 | 
 | ||||||
|     @defer.inlineCallbacks |     @defer.inlineCallbacks | ||||||
|  | |||||||
| @ -89,7 +89,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): | |||||||
|     @defer.inlineCallbacks |     @defer.inlineCallbacks | ||||||
|     def test_select_one_1col(self): |     def test_select_one_1col(self): | ||||||
|         self.mock_txn.rowcount = 1 |         self.mock_txn.rowcount = 1 | ||||||
|         self.mock_txn.fetchall.return_value = [("Value",)] |         self.mock_txn.__iter__ = Mock(return_value=iter([("Value",)])) | ||||||
| 
 | 
 | ||||||
|         value = yield self.datastore._simple_select_one_onecol( |         value = yield self.datastore._simple_select_one_onecol( | ||||||
|             table="tablename", |             table="tablename", | ||||||
| @ -136,7 +136,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): | |||||||
|     @defer.inlineCallbacks |     @defer.inlineCallbacks | ||||||
|     def test_select_list(self): |     def test_select_list(self): | ||||||
|         self.mock_txn.rowcount = 3 |         self.mock_txn.rowcount = 3 | ||||||
|         self.mock_txn.fetchall.return_value = ((1,), (2,), (3,)) |         self.mock_txn.__iter__ = Mock(return_value=iter([(1,), (2,), (3,)])) | ||||||
|         self.mock_txn.description = ( |         self.mock_txn.description = ( | ||||||
|             ("colA", None, None, None, None, None, None), |             ("colA", None, None, None, None, None, None), | ||||||
|         ) |         ) | ||||||
|  | |||||||
							
								
								
									
										53
									
								
								tests/storage/test_keys.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										53
									
								
								tests/storage/test_keys.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,53 @@ | |||||||
|  | # -*- coding: utf-8 -*- | ||||||
|  | # Copyright 2017 Vector Creations Ltd | ||||||
|  | # | ||||||
|  | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  | # you may not use this file except in compliance with the License. | ||||||
|  | # You may obtain a copy of the License at | ||||||
|  | # | ||||||
|  | #     http://www.apache.org/licenses/LICENSE-2.0 | ||||||
|  | # | ||||||
|  | # Unless required by applicable law or agreed to in writing, software | ||||||
|  | # distributed under the License is distributed on an "AS IS" BASIS, | ||||||
|  | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
|  | # See the License for the specific language governing permissions and | ||||||
|  | # limitations under the License. | ||||||
|  | 
 | ||||||
|  | import signedjson.key | ||||||
|  | from twisted.internet import defer | ||||||
|  | 
 | ||||||
|  | import tests.unittest | ||||||
|  | import tests.utils | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class KeyStoreTestCase(tests.unittest.TestCase): | ||||||
|  |     def __init__(self, *args, **kwargs): | ||||||
|  |         super(KeyStoreTestCase, self).__init__(*args, **kwargs) | ||||||
|  |         self.store = None  # type: synapse.storage.keys.KeyStore | ||||||
|  | 
 | ||||||
|  |     @defer.inlineCallbacks | ||||||
|  |     def setUp(self): | ||||||
|  |         hs = yield tests.utils.setup_test_homeserver() | ||||||
|  |         self.store = hs.get_datastore() | ||||||
|  | 
 | ||||||
|  |     @defer.inlineCallbacks | ||||||
|  |     def test_get_server_verify_keys(self): | ||||||
|  |         key1 = signedjson.key.decode_verify_key_base64( | ||||||
|  |             "ed25519", "key1", "fP5l4JzpZPq/zdbBg5xx6lQGAAOM9/3w94cqiJ5jPrw" | ||||||
|  |         ) | ||||||
|  |         key2 = signedjson.key.decode_verify_key_base64( | ||||||
|  |             "ed25519", "key2", "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw" | ||||||
|  |         ) | ||||||
|  |         yield self.store.store_server_verify_key( | ||||||
|  |             "server1", "from_server", 0, key1 | ||||||
|  |         ) | ||||||
|  |         yield self.store.store_server_verify_key( | ||||||
|  |             "server1", "from_server", 0, key2 | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         res = yield self.store.get_server_verify_keys( | ||||||
|  |             "server1", ["ed25519:key1", "ed25519:key2", "ed25519:key3"]) | ||||||
|  | 
 | ||||||
|  |         self.assertEqual(len(res.keys()), 2) | ||||||
|  |         self.assertEqual(res["ed25519:key1"].version, "key1") | ||||||
|  |         self.assertEqual(res["ed25519:key2"].version, "key2") | ||||||
							
								
								
									
										14
									
								
								tests/util/caches/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										14
									
								
								tests/util/caches/__init__.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,14 @@ | |||||||
|  | # -*- coding: utf-8 -*- | ||||||
|  | # Copyright 2017 Vector Creations Ltd | ||||||
|  | # | ||||||
|  | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  | # you may not use this file except in compliance with the License. | ||||||
|  | # You may obtain a copy of the License at | ||||||
|  | # | ||||||
|  | #     http://www.apache.org/licenses/LICENSE-2.0 | ||||||
|  | # | ||||||
|  | # Unless required by applicable law or agreed to in writing, software | ||||||
|  | # distributed under the License is distributed on an "AS IS" BASIS, | ||||||
|  | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
|  | # See the License for the specific language governing permissions and | ||||||
|  | # limitations under the License. | ||||||
							
								
								
									
										177
									
								
								tests/util/caches/test_descriptors.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										177
									
								
								tests/util/caches/test_descriptors.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,177 @@ | |||||||
|  | # -*- coding: utf-8 -*- | ||||||
|  | # Copyright 2016 OpenMarket Ltd | ||||||
|  | # | ||||||
|  | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  | # you may not use this file except in compliance with the License. | ||||||
|  | # You may obtain a copy of the License at | ||||||
|  | # | ||||||
|  | #     http://www.apache.org/licenses/LICENSE-2.0 | ||||||
|  | # | ||||||
|  | # Unless required by applicable law or agreed to in writing, software | ||||||
|  | # distributed under the License is distributed on an "AS IS" BASIS, | ||||||
|  | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
|  | # See the License for the specific language governing permissions and | ||||||
|  | # limitations under the License. | ||||||
|  | import logging | ||||||
|  | 
 | ||||||
|  | import mock | ||||||
|  | from synapse.api.errors import SynapseError | ||||||
|  | from synapse.util import async | ||||||
|  | from synapse.util import logcontext | ||||||
|  | from twisted.internet import defer | ||||||
|  | from synapse.util.caches import descriptors | ||||||
|  | from tests import unittest | ||||||
|  | 
 | ||||||
|  | logger = logging.getLogger(__name__) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class DescriptorTestCase(unittest.TestCase): | ||||||
|  |     @defer.inlineCallbacks | ||||||
|  |     def test_cache(self): | ||||||
|  |         class Cls(object): | ||||||
|  |             def __init__(self): | ||||||
|  |                 self.mock = mock.Mock() | ||||||
|  | 
 | ||||||
|  |             @descriptors.cached() | ||||||
|  |             def fn(self, arg1, arg2): | ||||||
|  |                 return self.mock(arg1, arg2) | ||||||
|  | 
 | ||||||
|  |         obj = Cls() | ||||||
|  | 
 | ||||||
|  |         obj.mock.return_value = 'fish' | ||||||
|  |         r = yield obj.fn(1, 2) | ||||||
|  |         self.assertEqual(r, 'fish') | ||||||
|  |         obj.mock.assert_called_once_with(1, 2) | ||||||
|  |         obj.mock.reset_mock() | ||||||
|  | 
 | ||||||
|  |         # a call with different params should call the mock again | ||||||
|  |         obj.mock.return_value = 'chips' | ||||||
|  |         r = yield obj.fn(1, 3) | ||||||
|  |         self.assertEqual(r, 'chips') | ||||||
|  |         obj.mock.assert_called_once_with(1, 3) | ||||||
|  |         obj.mock.reset_mock() | ||||||
|  | 
 | ||||||
|  |         # the two values should now be cached | ||||||
|  |         r = yield obj.fn(1, 2) | ||||||
|  |         self.assertEqual(r, 'fish') | ||||||
|  |         r = yield obj.fn(1, 3) | ||||||
|  |         self.assertEqual(r, 'chips') | ||||||
|  |         obj.mock.assert_not_called() | ||||||
|  | 
 | ||||||
|  |     @defer.inlineCallbacks | ||||||
|  |     def test_cache_num_args(self): | ||||||
|  |         """Only the first num_args arguments should matter to the cache""" | ||||||
|  | 
 | ||||||
|  |         class Cls(object): | ||||||
|  |             def __init__(self): | ||||||
|  |                 self.mock = mock.Mock() | ||||||
|  | 
 | ||||||
|  |             @descriptors.cached(num_args=1) | ||||||
|  |             def fn(self, arg1, arg2): | ||||||
|  |                 return self.mock(arg1, arg2) | ||||||
|  | 
 | ||||||
|  |         obj = Cls() | ||||||
|  |         obj.mock.return_value = 'fish' | ||||||
|  |         r = yield obj.fn(1, 2) | ||||||
|  |         self.assertEqual(r, 'fish') | ||||||
|  |         obj.mock.assert_called_once_with(1, 2) | ||||||
|  |         obj.mock.reset_mock() | ||||||
|  | 
 | ||||||
|  |         # a call with different params should call the mock again | ||||||
|  |         obj.mock.return_value = 'chips' | ||||||
|  |         r = yield obj.fn(2, 3) | ||||||
|  |         self.assertEqual(r, 'chips') | ||||||
|  |         obj.mock.assert_called_once_with(2, 3) | ||||||
|  |         obj.mock.reset_mock() | ||||||
|  | 
 | ||||||
|  |         # the two values should now be cached; we should be able to vary | ||||||
|  |         # the second argument and still get the cached result. | ||||||
|  |         r = yield obj.fn(1, 4) | ||||||
|  |         self.assertEqual(r, 'fish') | ||||||
|  |         r = yield obj.fn(2, 5) | ||||||
|  |         self.assertEqual(r, 'chips') | ||||||
|  |         obj.mock.assert_not_called() | ||||||
|  | 
 | ||||||
|  |     def test_cache_logcontexts(self): | ||||||
|  |         """Check that logcontexts are set and restored correctly when | ||||||
|  |         using the cache.""" | ||||||
|  | 
 | ||||||
|  |         complete_lookup = defer.Deferred() | ||||||
|  | 
 | ||||||
|  |         class Cls(object): | ||||||
|  |             @descriptors.cached() | ||||||
|  |             def fn(self, arg1): | ||||||
|  |                 @defer.inlineCallbacks | ||||||
|  |                 def inner_fn(): | ||||||
|  |                     with logcontext.PreserveLoggingContext(): | ||||||
|  |                         yield complete_lookup | ||||||
|  |                     defer.returnValue(1) | ||||||
|  | 
 | ||||||
|  |                 return inner_fn() | ||||||
|  | 
 | ||||||
|  |         @defer.inlineCallbacks | ||||||
|  |         def do_lookup(): | ||||||
|  |             with logcontext.LoggingContext() as c1: | ||||||
|  |                 c1.name = "c1" | ||||||
|  |                 r = yield obj.fn(1) | ||||||
|  |                 self.assertEqual(logcontext.LoggingContext.current_context(), | ||||||
|  |                                  c1) | ||||||
|  |             defer.returnValue(r) | ||||||
|  | 
 | ||||||
|  |         def check_result(r): | ||||||
|  |             self.assertEqual(r, 1) | ||||||
|  | 
 | ||||||
|  |         obj = Cls() | ||||||
|  | 
 | ||||||
|  |         # set off a deferred which will do a cache lookup | ||||||
|  |         d1 = do_lookup() | ||||||
|  |         self.assertEqual(logcontext.LoggingContext.current_context(), | ||||||
|  |                          logcontext.LoggingContext.sentinel) | ||||||
|  |         d1.addCallback(check_result) | ||||||
|  | 
 | ||||||
|  |         # and another | ||||||
|  |         d2 = do_lookup() | ||||||
|  |         self.assertEqual(logcontext.LoggingContext.current_context(), | ||||||
|  |                          logcontext.LoggingContext.sentinel) | ||||||
|  |         d2.addCallback(check_result) | ||||||
|  | 
 | ||||||
|  |         # let the lookup complete | ||||||
|  |         complete_lookup.callback(None) | ||||||
|  | 
 | ||||||
|  |         return defer.gatherResults([d1, d2]) | ||||||
|  | 
 | ||||||
|  |     def test_cache_logcontexts_with_exception(self): | ||||||
|  |         """Check that the cache sets and restores logcontexts correctly when | ||||||
|  |         the lookup function throws an exception""" | ||||||
|  | 
 | ||||||
|  |         class Cls(object): | ||||||
|  |             @descriptors.cached() | ||||||
|  |             def fn(self, arg1): | ||||||
|  |                 @defer.inlineCallbacks | ||||||
|  |                 def inner_fn(): | ||||||
|  |                     yield async.run_on_reactor() | ||||||
|  |                     raise SynapseError(400, "blah") | ||||||
|  | 
 | ||||||
|  |                 return inner_fn() | ||||||
|  | 
 | ||||||
|  |         @defer.inlineCallbacks | ||||||
|  |         def do_lookup(): | ||||||
|  |             with logcontext.LoggingContext() as c1: | ||||||
|  |                 c1.name = "c1" | ||||||
|  |                 try: | ||||||
|  |                     yield obj.fn(1) | ||||||
|  |                     self.fail("No exception thrown") | ||||||
|  |                 except SynapseError: | ||||||
|  |                     pass | ||||||
|  | 
 | ||||||
|  |                 self.assertEqual(logcontext.LoggingContext.current_context(), | ||||||
|  |                                  c1) | ||||||
|  | 
 | ||||||
|  |         obj = Cls() | ||||||
|  | 
 | ||||||
|  |         # set off a deferred which will do a cache lookup | ||||||
|  |         d1 = do_lookup() | ||||||
|  |         self.assertEqual(logcontext.LoggingContext.current_context(), | ||||||
|  |                          logcontext.LoggingContext.sentinel) | ||||||
|  | 
 | ||||||
|  |         return d1 | ||||||
							
								
								
									
										33
									
								
								tests/util/test_clock.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										33
									
								
								tests/util/test_clock.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,33 @@ | |||||||
|  | # -*- coding: utf-8 -*- | ||||||
|  | # Copyright 2017 Vector Creations Ltd | ||||||
|  | # | ||||||
|  | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  | # you may not use this file except in compliance with the License. | ||||||
|  | # You may obtain a copy of the License at | ||||||
|  | # | ||||||
|  | #     http://www.apache.org/licenses/LICENSE-2.0 | ||||||
|  | # | ||||||
|  | # Unless required by applicable law or agreed to in writing, software | ||||||
|  | # distributed under the License is distributed on an "AS IS" BASIS, | ||||||
|  | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
|  | # See the License for the specific language governing permissions and | ||||||
|  | # limitations under the License. | ||||||
|  | from synapse import util | ||||||
|  | from twisted.internet import defer | ||||||
|  | from tests import unittest | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class ClockTestCase(unittest.TestCase): | ||||||
|  |     @defer.inlineCallbacks | ||||||
|  |     def test_time_bound_deferred(self): | ||||||
|  |         # just a deferred which never resolves | ||||||
|  |         slow_deferred = defer.Deferred() | ||||||
|  | 
 | ||||||
|  |         clock = util.Clock() | ||||||
|  |         time_bound = clock.time_bound_deferred(slow_deferred, 0.001) | ||||||
|  | 
 | ||||||
|  |         try: | ||||||
|  |             yield time_bound | ||||||
|  |             self.fail("Expected timedout error, but got nothing") | ||||||
|  |         except util.DeferredTimedOutError: | ||||||
|  |             pass | ||||||
| @ -1,8 +1,10 @@ | |||||||
|  | import twisted.python.failure | ||||||
| from twisted.internet import defer | from twisted.internet import defer | ||||||
| from twisted.internet import reactor | from twisted.internet import reactor | ||||||
| from .. import unittest | from .. import unittest | ||||||
| 
 | 
 | ||||||
| from synapse.util.async import sleep | from synapse.util.async import sleep | ||||||
|  | from synapse.util import logcontext | ||||||
| from synapse.util.logcontext import LoggingContext | from synapse.util.logcontext import LoggingContext | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @ -33,3 +35,62 @@ class LoggingContextTestCase(unittest.TestCase): | |||||||
|             context_one.test_key = "one" |             context_one.test_key = "one" | ||||||
|             yield sleep(0) |             yield sleep(0) | ||||||
|             self._check_test_key("one") |             self._check_test_key("one") | ||||||
|  | 
 | ||||||
|  |     def _test_preserve_fn(self, function): | ||||||
|  |         sentinel_context = LoggingContext.current_context() | ||||||
|  | 
 | ||||||
|  |         callback_completed = [False] | ||||||
|  | 
 | ||||||
|  |         @defer.inlineCallbacks | ||||||
|  |         def cb(): | ||||||
|  |             context_one.test_key = "one" | ||||||
|  |             yield function() | ||||||
|  |             self._check_test_key("one") | ||||||
|  | 
 | ||||||
|  |             callback_completed[0] = True | ||||||
|  | 
 | ||||||
|  |         with LoggingContext() as context_one: | ||||||
|  |             context_one.test_key = "one" | ||||||
|  | 
 | ||||||
|  |             # fire off function, but don't wait on it. | ||||||
|  |             logcontext.preserve_fn(cb)() | ||||||
|  | 
 | ||||||
|  |             self._check_test_key("one") | ||||||
|  | 
 | ||||||
|  |         # now wait for the function under test to have run, and check that | ||||||
|  |         # the logcontext is left in a sane state. | ||||||
|  |         d2 = defer.Deferred() | ||||||
|  | 
 | ||||||
|  |         def check_logcontext(): | ||||||
|  |             if not callback_completed[0]: | ||||||
|  |                 reactor.callLater(0.01, check_logcontext) | ||||||
|  |                 return | ||||||
|  | 
 | ||||||
|  |             # make sure that the context was reset before it got thrown back | ||||||
|  |             # into the reactor | ||||||
|  |             try: | ||||||
|  |                 self.assertIs(LoggingContext.current_context(), | ||||||
|  |                               sentinel_context) | ||||||
|  |                 d2.callback(None) | ||||||
|  |             except BaseException: | ||||||
|  |                 d2.errback(twisted.python.failure.Failure()) | ||||||
|  | 
 | ||||||
|  |         reactor.callLater(0.01, check_logcontext) | ||||||
|  | 
 | ||||||
|  |         # test is done once d2 finishes | ||||||
|  |         return d2 | ||||||
|  | 
 | ||||||
|  |     def test_preserve_fn_with_blocking_fn(self): | ||||||
|  |         @defer.inlineCallbacks | ||||||
|  |         def blocking_function(): | ||||||
|  |             yield sleep(0) | ||||||
|  | 
 | ||||||
|  |         return self._test_preserve_fn(blocking_function) | ||||||
|  | 
 | ||||||
|  |     def test_preserve_fn_with_non_blocking_fn(self): | ||||||
|  |         @defer.inlineCallbacks | ||||||
|  |         def nonblocking_function(): | ||||||
|  |             with logcontext.PreserveLoggingContext(): | ||||||
|  |                 yield defer.succeed(None) | ||||||
|  | 
 | ||||||
|  |         return self._test_preserve_fn(nonblocking_function) | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user