mirror of
				https://github.com/matrix-org/synapse.git
				synced 2025-10-31 00:01:33 +01:00 
			
		
		
		
	Merge branch 'release-v0.20.0' of github.com:matrix-org/synapse
This commit is contained in:
		
						commit
						4902db1fc9
					
				
							
								
								
									
										53
									
								
								CHANGES.rst
									
									
									
									
									
								
							
							
						
						
									
										53
									
								
								CHANGES.rst
									
									
									
									
									
								
							| @ -1,3 +1,56 @@ | ||||
| Changes in synapse v0.20.0 (2017-04-11) | ||||
| ======================================= | ||||
| 
 | ||||
| Bug fixes: | ||||
| 
 | ||||
| * Fix joining rooms over federation where not all servers in the room saw the | ||||
|   new server had joined (PR #2094) | ||||
| 
 | ||||
| 
 | ||||
| Changes in synapse v0.20.0-rc1 (2017-03-30) | ||||
| =========================================== | ||||
| 
 | ||||
| Features: | ||||
| 
 | ||||
| * Add delete_devices API (PR #1993) | ||||
| * Add phone number registration/login support (PR #1994, #2055) | ||||
| 
 | ||||
| 
 | ||||
| Changes: | ||||
| 
 | ||||
| * Use JSONSchema for validation of filters. Thanks @pik! (PR #1783) | ||||
| * Reread log config on SIGHUP (PR #1982) | ||||
| * Speed up public room list (PR #1989) | ||||
| * Add helpful texts to logger config options (PR #1990) | ||||
| * Minor ``/sync`` performance improvements. (PR #2002, #2013, #2022) | ||||
| * Add some debug to help diagnose weird federation issue (PR #2035) | ||||
| * Correctly limit retries for all federation requests (PR #2050, #2061) | ||||
| * Don't lock table when persisting new one time keys (PR #2053) | ||||
| * Reduce some CPU work on DB threads (PR #2054) | ||||
| * Cache hosts in room (PR #2060) | ||||
| * Batch sending of device list pokes (PR #2063) | ||||
| * Speed up persist event path in certain edge cases (PR #2070) | ||||
| 
 | ||||
| 
 | ||||
| Bug fixes: | ||||
| 
 | ||||
| * Fix bug where current_state_events renamed to current_state_ids (PR #1849) | ||||
| * Fix routing loop when fetching remote media (PR #1992) | ||||
| * Fix current_state_events table to not lie (PR #1996) | ||||
| * Fix CAS login to handle PartialDownloadError (PR #1997) | ||||
| * Fix assertion to stop transaction queue getting wedged (PR #2010) | ||||
| * Fix presence to fallback to last_active_ts if it beats the last sync time. | ||||
|   Thanks @Half-Shot! (PR #2014) | ||||
| * Fix bug when federation received a PDU while a room join is in progress (PR | ||||
|   #2016) | ||||
| * Fix resetting state on rejected events (PR #2025) | ||||
| * Fix installation issues in readme. Thanks @ricco386 (PR #2037) | ||||
| * Fix caching of remote servers' signature keys (PR #2042) | ||||
| * Fix some leaking log context (PR #2048, #2049, #2057, #2058) | ||||
| * Fix rejection of invites not reaching sync (PR #2056) | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| Changes in synapse v0.19.3 (2017-03-20) | ||||
| ======================================= | ||||
| 
 | ||||
|  | ||||
| @ -146,6 +146,7 @@ To install the synapse homeserver run:: | ||||
| 
 | ||||
|     virtualenv -p python2.7 ~/.synapse | ||||
|     source ~/.synapse/bin/activate | ||||
|     pip install --upgrade pip | ||||
|     pip install --upgrade setuptools | ||||
|     pip install https://github.com/matrix-org/synapse/tarball/master | ||||
| 
 | ||||
| @ -228,6 +229,7 @@ To get started, it is easiest to use the command line to register new users:: | ||||
|     New user localpart: erikj | ||||
|     Password: | ||||
|     Confirm password: | ||||
|     Make admin [no]: | ||||
|     Success! | ||||
| 
 | ||||
| This process uses a setting ``registration_shared_secret`` in | ||||
| @ -808,7 +810,7 @@ directory of your choice:: | ||||
| Synapse has a number of external dependencies, that are easiest | ||||
| to install using pip and a virtualenv:: | ||||
| 
 | ||||
|     virtualenv env | ||||
|     virtualenv -p python2.7 env | ||||
|     source env/bin/activate | ||||
|     python synapse/python_dependencies.py | xargs pip install | ||||
|     pip install lxml mock | ||||
|  | ||||
| @ -39,9 +39,11 @@ loggers: | ||||
|     synapse: | ||||
|         level: INFO | ||||
| 
 | ||||
|     synapse.storage: | ||||
|     synapse.storage.SQL: | ||||
|         # beware: increasing this to DEBUG will make synapse log sensitive | ||||
|         # information such as access tokens. | ||||
|         level: INFO | ||||
|      | ||||
| 
 | ||||
|     # example of enabling debugging for a component: | ||||
|     # | ||||
|     # synapse.federation.transport.server: | ||||
|  | ||||
| @ -1,10 +1,446 @@ | ||||
| What do I do about "Unexpected logging context" debug log-lines everywhere? | ||||
| Log contexts | ||||
| ============ | ||||
| 
 | ||||
| <Mjark> The logging context lives in thread local storage | ||||
| <Mjark> Sometimes it gets out of sync with what it should actually be, usually because something scheduled something to run on the reactor without preserving the logging context.  | ||||
| <Matthew> what is the impact of it getting out of sync? and how and when should we preserve log context? | ||||
| <Mjark> The impact is that some of the CPU and database metrics will be under-reported, and some log lines will be mis-attributed. | ||||
| <Mjark> It should happen auto-magically in all the APIs that do IO or otherwise defer to the reactor. | ||||
| <Erik> Mjark: the other place is if we branch, e.g. using defer.gatherResults | ||||
| .. contents:: | ||||
| 
 | ||||
| Unanswered: how and when should we preserve log context? | ||||
| To help track the processing of individual requests, synapse uses a | ||||
| 'log context' to track which request it is handling at any given moment. This | ||||
| is done via a thread-local variable; a ``logging.Filter`` is then used to fish | ||||
| the information back out of the thread-local variable and add it to each log | ||||
| record. | ||||
| 
 | ||||
| Logcontexts are also used for CPU and database accounting, so that we can track | ||||
| which requests were responsible for high CPU use or database activity. | ||||
| 
 | ||||
| The ``synapse.util.logcontext`` module provides a facilities for managing the | ||||
| current log context (as well as providing the ``LoggingContextFilter`` class). | ||||
| 
 | ||||
| Deferreds make the whole thing complicated, so this document describes how it | ||||
| all works, and how to write code which follows the rules. | ||||
| 
 | ||||
| Logcontexts without Deferreds | ||||
| ----------------------------- | ||||
| 
 | ||||
| In the absence of any Deferred voodoo, things are simple enough. As with any | ||||
| code of this nature, the rule is that our function should leave things as it | ||||
| found them: | ||||
| 
 | ||||
| .. code:: python | ||||
| 
 | ||||
|     from synapse.util import logcontext         # omitted from future snippets | ||||
| 
 | ||||
|     def handle_request(request_id): | ||||
|         request_context = logcontext.LoggingContext() | ||||
| 
 | ||||
|         calling_context = logcontext.LoggingContext.current_context() | ||||
|         logcontext.LoggingContext.set_current_context(request_context) | ||||
|         try: | ||||
|             request_context.request = request_id | ||||
|             do_request_handling() | ||||
|             logger.debug("finished") | ||||
|         finally: | ||||
|             logcontext.LoggingContext.set_current_context(calling_context) | ||||
| 
 | ||||
|     def do_request_handling(): | ||||
|         logger.debug("phew")  # this will be logged against request_id | ||||
| 
 | ||||
| 
 | ||||
| LoggingContext implements the context management methods, so the above can be | ||||
| written much more succinctly as: | ||||
| 
 | ||||
| .. code:: python | ||||
| 
 | ||||
|     def handle_request(request_id): | ||||
|         with logcontext.LoggingContext() as request_context: | ||||
|             request_context.request = request_id | ||||
|             do_request_handling() | ||||
|             logger.debug("finished") | ||||
| 
 | ||||
|     def do_request_handling(): | ||||
|         logger.debug("phew") | ||||
| 
 | ||||
| 
 | ||||
| Using logcontexts with Deferreds | ||||
| -------------------------------- | ||||
| 
 | ||||
| Deferreds — and in particular, ``defer.inlineCallbacks`` — break | ||||
| the linear flow of code so that there is no longer a single entry point where | ||||
| we should set the logcontext and a single exit point where we should remove it. | ||||
| 
 | ||||
| Consider the example above, where ``do_request_handling`` needs to do some | ||||
| blocking operation, and returns a deferred: | ||||
| 
 | ||||
| .. code:: python | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def handle_request(request_id): | ||||
|         with logcontext.LoggingContext() as request_context: | ||||
|             request_context.request = request_id | ||||
|             yield do_request_handling() | ||||
|             logger.debug("finished") | ||||
| 
 | ||||
| 
 | ||||
| In the above flow: | ||||
| 
 | ||||
| * The logcontext is set | ||||
| * ``do_request_handling`` is called, and returns a deferred | ||||
| * ``handle_request`` yields the deferred | ||||
| * The ``inlineCallbacks`` wrapper of ``handle_request`` returns a deferred | ||||
| 
 | ||||
| So we have stopped processing the request (and will probably go on to start | ||||
| processing the next), without clearing the logcontext. | ||||
| 
 | ||||
| To circumvent this problem, synapse code assumes that, wherever you have a | ||||
| deferred, you will want to yield on it. To that end, whereever functions return | ||||
| a deferred, we adopt the following conventions: | ||||
| 
 | ||||
| **Rules for functions returning deferreds:** | ||||
| 
 | ||||
|   * If the deferred is already complete, the function returns with the same | ||||
|     logcontext it started with. | ||||
|   * If the deferred is incomplete, the function clears the logcontext before | ||||
|     returning; when the deferred completes, it restores the logcontext before | ||||
|     running any callbacks. | ||||
| 
 | ||||
| That sounds complicated, but actually it means a lot of code (including the | ||||
| example above) "just works". There are two cases: | ||||
| 
 | ||||
| * If ``do_request_handling`` returns a completed deferred, then the logcontext | ||||
|   will still be in place. In this case, execution will continue immediately | ||||
|   after the ``yield``; the "finished" line will be logged against the right | ||||
|   context, and the ``with`` block restores the original context before we | ||||
|   return to the caller. | ||||
| 
 | ||||
| * If the returned deferred is incomplete, ``do_request_handling`` clears the | ||||
|   logcontext before returning. The logcontext is therefore clear when | ||||
|   ``handle_request`` yields the deferred. At that point, the ``inlineCallbacks`` | ||||
|   wrapper adds a callback to the deferred, and returns another (incomplete) | ||||
|   deferred to the caller, and it is safe to begin processing the next request. | ||||
| 
 | ||||
|   Once ``do_request_handling``'s deferred completes, it will reinstate the | ||||
|   logcontext, before running the callback added by the ``inlineCallbacks`` | ||||
|   wrapper. That callback runs the second half of ``handle_request``, so again | ||||
|   the "finished" line will be logged against the right | ||||
|   context, and the ``with`` block restores the original context. | ||||
| 
 | ||||
| As an aside, it's worth noting that ``handle_request`` follows our rules - | ||||
| though that only matters if the caller has its own logcontext which it cares | ||||
| about. | ||||
| 
 | ||||
| The following sections describe pitfalls and helpful patterns when implementing | ||||
| these rules. | ||||
| 
 | ||||
| Always yield your deferreds | ||||
| --------------------------- | ||||
| 
 | ||||
| Whenever you get a deferred back from a function, you should ``yield`` on it | ||||
| as soon as possible. (Returning it directly to your caller is ok too, if you're | ||||
| not doing ``inlineCallbacks``.) Do not pass go; do not do any logging; do not | ||||
| call any other functions. | ||||
| 
 | ||||
| .. code:: python | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def fun(): | ||||
|         logger.debug("starting") | ||||
|         yield do_some_stuff()       # just like this | ||||
| 
 | ||||
|         d = more_stuff() | ||||
|         result = yield d            # also fine, of course | ||||
| 
 | ||||
|         defer.returnValue(result) | ||||
| 
 | ||||
|     def nonInlineCallbacksFun(): | ||||
|         logger.debug("just a wrapper really") | ||||
|         return do_some_stuff()      # this is ok too - the caller will yield on | ||||
|                                     # it anyway. | ||||
| 
 | ||||
| Provided this pattern is followed all the way back up to the callchain to where | ||||
| the logcontext was set, this will make things work out ok: provided | ||||
| ``do_some_stuff`` and ``more_stuff`` follow the rules above, then so will | ||||
| ``fun`` (as wrapped by ``inlineCallbacks``) and ``nonInlineCallbacksFun``. | ||||
| 
 | ||||
| It's all too easy to forget to ``yield``: for instance if we forgot that | ||||
| ``do_some_stuff`` returned a deferred, we might plough on regardless. This | ||||
| leads to a mess; it will probably work itself out eventually, but not before | ||||
| a load of stuff has been logged against the wrong content. (Normally, other | ||||
| things will break, more obviously, if you forget to ``yield``, so this tends | ||||
| not to be a major problem in practice.) | ||||
| 
 | ||||
| Of course sometimes you need to do something a bit fancier with your Deferreds | ||||
| - not all code follows the linear A-then-B-then-C pattern. Notes on | ||||
| implementing more complex patterns are in later sections. | ||||
| 
 | ||||
| Where you create a new Deferred, make it follow the rules | ||||
| --------------------------------------------------------- | ||||
| 
 | ||||
| Most of the time, a Deferred comes from another synapse function. Sometimes, | ||||
| though, we need to make up a new Deferred, or we get a Deferred back from | ||||
| external code. We need to make it follow our rules. | ||||
| 
 | ||||
| The easy way to do it is with a combination of ``defer.inlineCallbacks``, and | ||||
| ``logcontext.PreserveLoggingContext``. Suppose we want to implement ``sleep``, | ||||
| which returns a deferred which will run its callbacks after a given number of | ||||
| seconds. That might look like: | ||||
| 
 | ||||
| .. code:: python | ||||
| 
 | ||||
|     # not a logcontext-rules-compliant function | ||||
|     def get_sleep_deferred(seconds): | ||||
|         d = defer.Deferred() | ||||
|         reactor.callLater(seconds, d.callback, None) | ||||
|         return d | ||||
| 
 | ||||
| That doesn't follow the rules, but we can fix it by wrapping it with | ||||
| ``PreserveLoggingContext`` and ``yield`` ing on it: | ||||
| 
 | ||||
| .. code:: python | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def sleep(seconds): | ||||
|         with PreserveLoggingContext(): | ||||
|             yield get_sleep_deferred(seconds) | ||||
| 
 | ||||
| This technique works equally for external functions which return deferreds, | ||||
| or deferreds we have made ourselves. | ||||
| 
 | ||||
| You can also use ``logcontext.make_deferred_yieldable``, which just does the | ||||
| boilerplate for you, so the above could be written: | ||||
| 
 | ||||
| .. code:: python | ||||
| 
 | ||||
|     def sleep(seconds): | ||||
|         return logcontext.make_deferred_yieldable(get_sleep_deferred(seconds)) | ||||
| 
 | ||||
| 
 | ||||
| Fire-and-forget | ||||
| --------------- | ||||
| 
 | ||||
| Sometimes you want to fire off a chain of execution, but not wait for its | ||||
| result. That might look a bit like this: | ||||
| 
 | ||||
| .. code:: python | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def do_request_handling(): | ||||
|         yield foreground_operation() | ||||
| 
 | ||||
|         # *don't* do this | ||||
|         background_operation() | ||||
| 
 | ||||
|         logger.debug("Request handling complete") | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def background_operation(): | ||||
|         yield first_background_step() | ||||
|         logger.debug("Completed first step") | ||||
|         yield second_background_step() | ||||
|         logger.debug("Completed second step") | ||||
| 
 | ||||
| The above code does a couple of steps in the background after | ||||
| ``do_request_handling`` has finished. The log lines are still logged against | ||||
| the ``request_context`` logcontext, which may or may not be desirable. There | ||||
| are two big problems with the above, however. The first problem is that, if | ||||
| ``background_operation`` returns an incomplete Deferred, it will expect its | ||||
| caller to ``yield`` immediately, so will have cleared the logcontext. In this | ||||
| example, that means that 'Request handling complete' will be logged without any | ||||
| context. | ||||
| 
 | ||||
| The second problem, which is potentially even worse, is that when the Deferred | ||||
| returned by ``background_operation`` completes, it will restore the original | ||||
| logcontext. There is nothing waiting on that Deferred, so the logcontext will | ||||
| leak into the reactor and possibly get attached to some arbitrary future | ||||
| operation. | ||||
| 
 | ||||
| There are two potential solutions to this. | ||||
| 
 | ||||
| One option is to surround the call to ``background_operation`` with a | ||||
| ``PreserveLoggingContext`` call. That will reset the logcontext before | ||||
| starting ``background_operation`` (so the context restored when the deferred | ||||
| completes will be the empty logcontext), and will restore the current | ||||
| logcontext before continuing the foreground process: | ||||
| 
 | ||||
| .. code:: python | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def do_request_handling(): | ||||
|         yield foreground_operation() | ||||
| 
 | ||||
|         # start background_operation off in the empty logcontext, to | ||||
|         # avoid leaking the current context into the reactor. | ||||
|         with PreserveLoggingContext(): | ||||
|             background_operation() | ||||
| 
 | ||||
|         # this will now be logged against the request context | ||||
|         logger.debug("Request handling complete") | ||||
| 
 | ||||
| Obviously that option means that the operations done in | ||||
| ``background_operation`` would be not be logged against a logcontext (though | ||||
| that might be fixed by setting a different logcontext via a ``with | ||||
| LoggingContext(...)`` in ``background_operation``). | ||||
| 
 | ||||
| The second option is to use ``logcontext.preserve_fn``, which wraps a function | ||||
| so that it doesn't reset the logcontext even when it returns an incomplete | ||||
| deferred, and adds a callback to the returned deferred to reset the | ||||
| logcontext. In other words, it turns a function that follows the Synapse rules | ||||
| about logcontexts and Deferreds into one which behaves more like an external | ||||
| function — the opposite operation to that described in the previous section. | ||||
| It can be used like this: | ||||
| 
 | ||||
| .. code:: python | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def do_request_handling(): | ||||
|         yield foreground_operation() | ||||
| 
 | ||||
|         logcontext.preserve_fn(background_operation)() | ||||
| 
 | ||||
|         # this will now be logged against the request context | ||||
|         logger.debug("Request handling complete") | ||||
| 
 | ||||
| XXX: I think ``preserve_context_over_fn`` is supposed to do the first option, | ||||
| but the fact that it does ``preserve_context_over_deferred`` on its results | ||||
| means that its use is fraught with difficulty. | ||||
| 
 | ||||
| Passing synapse deferreds into third-party functions | ||||
| ---------------------------------------------------- | ||||
| 
 | ||||
| A typical example of this is where we want to collect together two or more | ||||
| deferred via ``defer.gatherResults``: | ||||
| 
 | ||||
| .. code:: python | ||||
| 
 | ||||
|     d1 = operation1() | ||||
|     d2 = operation2() | ||||
|     d3 = defer.gatherResults([d1, d2]) | ||||
| 
 | ||||
| This is really a variation of the fire-and-forget problem above, in that we are | ||||
| firing off ``d1`` and ``d2`` without yielding on them. The difference | ||||
| is that we now have third-party code attached to their callbacks. Anyway either | ||||
| technique given in the `Fire-and-forget`_ section will work. | ||||
| 
 | ||||
| Of course, the new Deferred returned by ``gatherResults`` needs to be wrapped | ||||
| in order to make it follow the logcontext rules before we can yield it, as | ||||
| described in `Where you create a new Deferred, make it follow the rules`_. | ||||
| 
 | ||||
| So, option one: reset the logcontext before starting the operations to be | ||||
| gathered: | ||||
| 
 | ||||
| .. code:: python | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def do_request_handling(): | ||||
|         with PreserveLoggingContext(): | ||||
|             d1 = operation1() | ||||
|             d2 = operation2() | ||||
|             result = yield defer.gatherResults([d1, d2]) | ||||
| 
 | ||||
| In this case particularly, though, option two, of using | ||||
| ``logcontext.preserve_fn`` almost certainly makes more sense, so that | ||||
| ``operation1`` and ``operation2`` are both logged against the original | ||||
| logcontext. This looks like: | ||||
| 
 | ||||
| .. code:: python | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def do_request_handling(): | ||||
|         d1 = logcontext.preserve_fn(operation1)() | ||||
|         d2 = logcontext.preserve_fn(operation2)() | ||||
| 
 | ||||
|         with PreserveLoggingContext(): | ||||
|             result = yield defer.gatherResults([d1, d2]) | ||||
| 
 | ||||
| 
 | ||||
| Was all this really necessary? | ||||
| ------------------------------ | ||||
| 
 | ||||
| The conventions used work fine for a linear flow where everything happens in | ||||
| series via ``defer.inlineCallbacks`` and ``yield``, but are certainly tricky to | ||||
| follow for any more exotic flows. It's hard not to wonder if we could have done | ||||
| something else. | ||||
| 
 | ||||
| We're not going to rewrite Synapse now, so the following is entirely of | ||||
| academic interest, but I'd like to record some thoughts on an alternative | ||||
| approach. | ||||
| 
 | ||||
| I briefly prototyped some code following an alternative set of rules. I think | ||||
| it would work, but I certainly didn't get as far as thinking how it would | ||||
| interact with concepts as complicated as the cache descriptors. | ||||
| 
 | ||||
| My alternative rules were: | ||||
| 
 | ||||
| * functions always preserve the logcontext of their caller, whether or not they | ||||
|   are returning a Deferred. | ||||
| 
 | ||||
| * Deferreds returned by synapse functions run their callbacks in the same | ||||
|   context as the function was orignally called in. | ||||
| 
 | ||||
| The main point of this scheme is that everywhere that sets the logcontext is | ||||
| responsible for clearing it before returning control to the reactor. | ||||
| 
 | ||||
| So, for example, if you were the function which started a ``with | ||||
| LoggingContext`` block, you wouldn't ``yield`` within it — instead you'd start | ||||
| off the background process, and then leave the ``with`` block to wait for it: | ||||
| 
 | ||||
| .. code:: python | ||||
| 
 | ||||
|     def handle_request(request_id): | ||||
|         with logcontext.LoggingContext() as request_context: | ||||
|             request_context.request = request_id | ||||
|             d = do_request_handling() | ||||
| 
 | ||||
|         def cb(r): | ||||
|             logger.debug("finished") | ||||
| 
 | ||||
|         d.addCallback(cb) | ||||
|         return d | ||||
| 
 | ||||
| (in general, mixing ``with LoggingContext`` blocks and | ||||
| ``defer.inlineCallbacks`` in the same function leads to slighly | ||||
| counter-intuitive code, under this scheme). | ||||
| 
 | ||||
| Because we leave the original ``with`` block as soon as the Deferred is | ||||
| returned (as opposed to waiting for it to be resolved, as we do today), the | ||||
| logcontext is cleared before control passes back to the reactor; so if there is | ||||
| some code within ``do_request_handling`` which needs to wait for a Deferred to | ||||
| complete, there is no need for it to worry about clearing the logcontext before | ||||
| doing so: | ||||
| 
 | ||||
| .. code:: python | ||||
| 
 | ||||
|     def handle_request(): | ||||
|         r = do_some_stuff() | ||||
|         r.addCallback(do_some_more_stuff) | ||||
|         return r | ||||
| 
 | ||||
| — and provided ``do_some_stuff`` follows the rules of returning a Deferred which | ||||
| runs its callbacks in the original logcontext, all is happy. | ||||
| 
 | ||||
| The business of a Deferred which runs its callbacks in the original logcontext | ||||
| isn't hard to achieve — we have it today, in the shape of | ||||
| ``logcontext._PreservingContextDeferred``: | ||||
| 
 | ||||
| .. code:: python | ||||
| 
 | ||||
|     def do_some_stuff(): | ||||
|         deferred = do_some_io() | ||||
|         pcd = _PreservingContextDeferred(LoggingContext.current_context()) | ||||
|         deferred.chainDeferred(pcd) | ||||
|         return pcd | ||||
| 
 | ||||
| It turns out that, thanks to the way that Deferreds chain together, we | ||||
| automatically get the property of a context-preserving deferred with | ||||
| ``defer.inlineCallbacks``, provided the final Defered the function ``yields`` | ||||
| on has that property. So we can just write: | ||||
| 
 | ||||
| .. code:: python | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def handle_request(): | ||||
|         yield do_some_stuff() | ||||
|         yield do_some_more_stuff() | ||||
| 
 | ||||
| To conclude: I think this scheme would have worked equally well, with less | ||||
| danger of messing it up, and probably made some more esoteric code easier to | ||||
| write. But again — changing the conventions of the entire Synapse codebase is | ||||
| not a sensible option for the marginal improvement offered. | ||||
|  | ||||
| @ -16,4 +16,4 @@ | ||||
| """ This is a reference implementation of a Matrix home server. | ||||
| """ | ||||
| 
 | ||||
| __version__ = "0.19.3" | ||||
| __version__ = "0.20.0" | ||||
|  | ||||
| @ -23,7 +23,7 @@ from synapse import event_auth | ||||
| from synapse.api.constants import EventTypes, Membership, JoinRules | ||||
| from synapse.api.errors import AuthError, Codes | ||||
| from synapse.types import UserID | ||||
| from synapse.util.logcontext import preserve_context_over_fn | ||||
| from synapse.util import logcontext | ||||
| from synapse.util.metrics import Measure | ||||
| 
 | ||||
| logger = logging.getLogger(__name__) | ||||
| @ -209,8 +209,7 @@ class Auth(object): | ||||
|                 default=[""] | ||||
|             )[0] | ||||
|             if user and access_token and ip_addr: | ||||
|                 preserve_context_over_fn( | ||||
|                     self.store.insert_client_ip, | ||||
|                 logcontext.preserve_fn(self.store.insert_client_ip)( | ||||
|                     user=user, | ||||
|                     access_token=access_token, | ||||
|                     ip=ip_addr, | ||||
|  | ||||
| @ -1,5 +1,6 @@ | ||||
| # -*- coding: utf-8 -*- | ||||
| # Copyright 2014-2016 OpenMarket Ltd | ||||
| # Copyright 2017 Vector Creations Ltd | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| @ -44,6 +45,7 @@ class JoinRules(object): | ||||
| class LoginType(object): | ||||
|     PASSWORD = u"m.login.password" | ||||
|     EMAIL_IDENTITY = u"m.login.email.identity" | ||||
|     MSISDN = u"m.login.msisdn" | ||||
|     RECAPTCHA = u"m.login.recaptcha" | ||||
|     DUMMY = u"m.login.dummy" | ||||
| 
 | ||||
|  | ||||
| @ -15,6 +15,7 @@ | ||||
| 
 | ||||
| """Contains exceptions and error codes.""" | ||||
| 
 | ||||
| import json | ||||
| import logging | ||||
| 
 | ||||
| logger = logging.getLogger(__name__) | ||||
| @ -50,27 +51,35 @@ class Codes(object): | ||||
| 
 | ||||
| 
 | ||||
| class CodeMessageException(RuntimeError): | ||||
|     """An exception with integer code and message string attributes.""" | ||||
|     """An exception with integer code and message string attributes. | ||||
| 
 | ||||
|     Attributes: | ||||
|         code (int): HTTP error code | ||||
|         msg (str): string describing the error | ||||
|     """ | ||||
|     def __init__(self, code, msg): | ||||
|         super(CodeMessageException, self).__init__("%d: %s" % (code, msg)) | ||||
|         self.code = code | ||||
|         self.msg = msg | ||||
|         self.response_code_message = None | ||||
| 
 | ||||
|     def error_dict(self): | ||||
|         return cs_error(self.msg) | ||||
| 
 | ||||
| 
 | ||||
| class SynapseError(CodeMessageException): | ||||
|     """A base error which can be caught for all synapse events.""" | ||||
|     """A base exception type for matrix errors which have an errcode and error | ||||
|     message (as well as an HTTP status code). | ||||
| 
 | ||||
|     Attributes: | ||||
|         errcode (str): Matrix error code e.g 'M_FORBIDDEN' | ||||
|     """ | ||||
|     def __init__(self, code, msg, errcode=Codes.UNKNOWN): | ||||
|         """Constructs a synapse error. | ||||
| 
 | ||||
|         Args: | ||||
|             code (int): The integer error code (an HTTP response code) | ||||
|             msg (str): The human-readable error message. | ||||
|             err (str): The error code e.g 'M_FORBIDDEN' | ||||
|             errcode (str): The matrix error code e.g 'M_FORBIDDEN' | ||||
|         """ | ||||
|         super(SynapseError, self).__init__(code, msg) | ||||
|         self.errcode = errcode | ||||
| @ -81,6 +90,39 @@ class SynapseError(CodeMessageException): | ||||
|             self.errcode, | ||||
|         ) | ||||
| 
 | ||||
|     @classmethod | ||||
|     def from_http_response_exception(cls, err): | ||||
|         """Make a SynapseError based on an HTTPResponseException | ||||
| 
 | ||||
|         This is useful when a proxied request has failed, and we need to | ||||
|         decide how to map the failure onto a matrix error to send back to the | ||||
|         client. | ||||
| 
 | ||||
|         An attempt is made to parse the body of the http response as a matrix | ||||
|         error. If that succeeds, the errcode and error message from the body | ||||
|         are used as the errcode and error message in the new synapse error. | ||||
| 
 | ||||
|         Otherwise, the errcode is set to M_UNKNOWN, and the error message is | ||||
|         set to the reason code from the HTTP response. | ||||
| 
 | ||||
|         Args: | ||||
|             err (HttpResponseException): | ||||
| 
 | ||||
|         Returns: | ||||
|             SynapseError: | ||||
|         """ | ||||
|         # try to parse the body as json, to get better errcode/msg, but | ||||
|         # default to M_UNKNOWN with the HTTP status as the error text | ||||
|         try: | ||||
|             j = json.loads(err.response) | ||||
|         except ValueError: | ||||
|             j = {} | ||||
|         errcode = j.get('errcode', Codes.UNKNOWN) | ||||
|         errmsg = j.get('error', err.msg) | ||||
| 
 | ||||
|         res = SynapseError(err.code, errmsg, errcode) | ||||
|         return res | ||||
| 
 | ||||
| 
 | ||||
| class RegistrationError(SynapseError): | ||||
|     """An error raised when a registration event fails.""" | ||||
| @ -106,13 +148,11 @@ class UnrecognizedRequestError(SynapseError): | ||||
| 
 | ||||
| class NotFoundError(SynapseError): | ||||
|     """An error indicating we can't find the thing you asked for""" | ||||
|     def __init__(self, *args, **kwargs): | ||||
|         if "errcode" not in kwargs: | ||||
|             kwargs["errcode"] = Codes.NOT_FOUND | ||||
|     def __init__(self, msg="Not found", errcode=Codes.NOT_FOUND): | ||||
|         super(NotFoundError, self).__init__( | ||||
|             404, | ||||
|             "Not found", | ||||
|             **kwargs | ||||
|             msg, | ||||
|             errcode=errcode | ||||
|         ) | ||||
| 
 | ||||
| 
 | ||||
| @ -173,7 +213,6 @@ class LimitExceededError(SynapseError): | ||||
|                  errcode=Codes.LIMIT_EXCEEDED): | ||||
|         super(LimitExceededError, self).__init__(code, msg, errcode) | ||||
|         self.retry_after_ms = retry_after_ms | ||||
|         self.response_code_message = "Too Many Requests" | ||||
| 
 | ||||
|     def error_dict(self): | ||||
|         return cs_error( | ||||
| @ -243,6 +282,19 @@ class FederationError(RuntimeError): | ||||
| 
 | ||||
| 
 | ||||
| class HttpResponseException(CodeMessageException): | ||||
|     """ | ||||
|     Represents an HTTP-level failure of an outbound request | ||||
| 
 | ||||
|     Attributes: | ||||
|         response (str): body of response | ||||
|     """ | ||||
|     def __init__(self, code, msg, response): | ||||
|         self.response = response | ||||
|         """ | ||||
| 
 | ||||
|         Args: | ||||
|             code (int): HTTP status code | ||||
|             msg (str): reason phrase from HTTP response status line | ||||
|             response (str): body of response | ||||
|         """ | ||||
|         super(HttpResponseException, self).__init__(code, msg) | ||||
|         self.response = response | ||||
|  | ||||
| @ -13,11 +13,174 @@ | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| from synapse.api.errors import SynapseError | ||||
| from synapse.storage.presence import UserPresenceState | ||||
| from synapse.types import UserID, RoomID | ||||
| 
 | ||||
| from twisted.internet import defer | ||||
| 
 | ||||
| import ujson as json | ||||
| import jsonschema | ||||
| from jsonschema import FormatChecker | ||||
| 
 | ||||
| FILTER_SCHEMA = { | ||||
|     "additionalProperties": False, | ||||
|     "type": "object", | ||||
|     "properties": { | ||||
|         "limit": { | ||||
|             "type": "number" | ||||
|         }, | ||||
|         "senders": { | ||||
|             "$ref": "#/definitions/user_id_array" | ||||
|         }, | ||||
|         "not_senders": { | ||||
|             "$ref": "#/definitions/user_id_array" | ||||
|         }, | ||||
|         # TODO: We don't limit event type values but we probably should... | ||||
|         # check types are valid event types | ||||
|         "types": { | ||||
|             "type": "array", | ||||
|             "items": { | ||||
|                 "type": "string" | ||||
|             } | ||||
|         }, | ||||
|         "not_types": { | ||||
|             "type": "array", | ||||
|             "items": { | ||||
|                 "type": "string" | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| ROOM_FILTER_SCHEMA = { | ||||
|     "additionalProperties": False, | ||||
|     "type": "object", | ||||
|     "properties": { | ||||
|         "not_rooms": { | ||||
|             "$ref": "#/definitions/room_id_array" | ||||
|         }, | ||||
|         "rooms": { | ||||
|             "$ref": "#/definitions/room_id_array" | ||||
|         }, | ||||
|         "ephemeral": { | ||||
|             "$ref": "#/definitions/room_event_filter" | ||||
|         }, | ||||
|         "include_leave": { | ||||
|             "type": "boolean" | ||||
|         }, | ||||
|         "state": { | ||||
|             "$ref": "#/definitions/room_event_filter" | ||||
|         }, | ||||
|         "timeline": { | ||||
|             "$ref": "#/definitions/room_event_filter" | ||||
|         }, | ||||
|         "account_data": { | ||||
|             "$ref": "#/definitions/room_event_filter" | ||||
|         }, | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| ROOM_EVENT_FILTER_SCHEMA = { | ||||
|     "additionalProperties": False, | ||||
|     "type": "object", | ||||
|     "properties": { | ||||
|         "limit": { | ||||
|             "type": "number" | ||||
|         }, | ||||
|         "senders": { | ||||
|             "$ref": "#/definitions/user_id_array" | ||||
|         }, | ||||
|         "not_senders": { | ||||
|             "$ref": "#/definitions/user_id_array" | ||||
|         }, | ||||
|         "types": { | ||||
|             "type": "array", | ||||
|             "items": { | ||||
|                 "type": "string" | ||||
|             } | ||||
|         }, | ||||
|         "not_types": { | ||||
|             "type": "array", | ||||
|             "items": { | ||||
|                 "type": "string" | ||||
|             } | ||||
|         }, | ||||
|         "rooms": { | ||||
|             "$ref": "#/definitions/room_id_array" | ||||
|         }, | ||||
|         "not_rooms": { | ||||
|             "$ref": "#/definitions/room_id_array" | ||||
|         }, | ||||
|         "contains_url": { | ||||
|             "type": "boolean" | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| USER_ID_ARRAY_SCHEMA = { | ||||
|     "type": "array", | ||||
|     "items": { | ||||
|         "type": "string", | ||||
|         "format": "matrix_user_id" | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| ROOM_ID_ARRAY_SCHEMA = { | ||||
|     "type": "array", | ||||
|     "items": { | ||||
|         "type": "string", | ||||
|         "format": "matrix_room_id" | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| USER_FILTER_SCHEMA = { | ||||
|     "$schema": "http://json-schema.org/draft-04/schema#", | ||||
|     "description": "schema for a Sync filter", | ||||
|     "type": "object", | ||||
|     "definitions": { | ||||
|         "room_id_array": ROOM_ID_ARRAY_SCHEMA, | ||||
|         "user_id_array": USER_ID_ARRAY_SCHEMA, | ||||
|         "filter": FILTER_SCHEMA, | ||||
|         "room_filter": ROOM_FILTER_SCHEMA, | ||||
|         "room_event_filter": ROOM_EVENT_FILTER_SCHEMA | ||||
|     }, | ||||
|     "properties": { | ||||
|         "presence": { | ||||
|             "$ref": "#/definitions/filter" | ||||
|         }, | ||||
|         "account_data": { | ||||
|             "$ref": "#/definitions/filter" | ||||
|         }, | ||||
|         "room": { | ||||
|             "$ref": "#/definitions/room_filter" | ||||
|         }, | ||||
|         "event_format": { | ||||
|             "type": "string", | ||||
|             "enum": ["client", "federation"] | ||||
|         }, | ||||
|         "event_fields": { | ||||
|             "type": "array", | ||||
|             "items": { | ||||
|                 "type": "string", | ||||
|                 # Don't allow '\\' in event field filters. This makes matching | ||||
|                 # events a lot easier as we can then use a negative lookbehind | ||||
|                 # assertion to split '\.' If we allowed \\ then it would | ||||
|                 # incorrectly split '\\.' See synapse.events.utils.serialize_event | ||||
|                 "pattern": "^((?!\\\).)*$" | ||||
|             } | ||||
|         } | ||||
|     }, | ||||
|     "additionalProperties": False | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
| @FormatChecker.cls_checks('matrix_room_id') | ||||
| def matrix_room_id_validator(room_id_str): | ||||
|     return RoomID.from_string(room_id_str) | ||||
| 
 | ||||
| 
 | ||||
| @FormatChecker.cls_checks('matrix_user_id') | ||||
| def matrix_user_id_validator(user_id_str): | ||||
|     return UserID.from_string(user_id_str) | ||||
| 
 | ||||
| 
 | ||||
| class Filtering(object): | ||||
| @ -52,98 +215,11 @@ class Filtering(object): | ||||
|         # NB: Filters are the complete json blobs. "Definitions" are an | ||||
|         # individual top-level key e.g. public_user_data. Filters are made of | ||||
|         # many definitions. | ||||
| 
 | ||||
|         top_level_definitions = [ | ||||
|             "presence", "account_data" | ||||
|         ] | ||||
| 
 | ||||
|         room_level_definitions = [ | ||||
|             "state", "timeline", "ephemeral", "account_data" | ||||
|         ] | ||||
| 
 | ||||
|         for key in top_level_definitions: | ||||
|             if key in user_filter_json: | ||||
|                 self._check_definition(user_filter_json[key]) | ||||
| 
 | ||||
|         if "room" in user_filter_json: | ||||
|             self._check_definition_room_lists(user_filter_json["room"]) | ||||
|             for key in room_level_definitions: | ||||
|                 if key in user_filter_json["room"]: | ||||
|                     self._check_definition(user_filter_json["room"][key]) | ||||
| 
 | ||||
|         if "event_fields" in user_filter_json: | ||||
|             if type(user_filter_json["event_fields"]) != list: | ||||
|                 raise SynapseError(400, "event_fields must be a list of strings") | ||||
|             for field in user_filter_json["event_fields"]: | ||||
|                 if not isinstance(field, basestring): | ||||
|                     raise SynapseError(400, "Event field must be a string") | ||||
|                 # Don't allow '\\' in event field filters. This makes matching | ||||
|                 # events a lot easier as we can then use a negative lookbehind | ||||
|                 # assertion to split '\.' If we allowed \\ then it would | ||||
|                 # incorrectly split '\\.' See synapse.events.utils.serialize_event | ||||
|                 if r'\\' in field: | ||||
|                     raise SynapseError( | ||||
|                         400, r'The escape character \ cannot itself be escaped' | ||||
|                     ) | ||||
| 
 | ||||
|     def _check_definition_room_lists(self, definition): | ||||
|         """Check that "rooms" and "not_rooms" are lists of room ids if they | ||||
|         are present | ||||
| 
 | ||||
|         Args: | ||||
|             definition(dict): The filter definition | ||||
|         Raises: | ||||
|             SynapseError: If there was a problem with this definition. | ||||
|         """ | ||||
|         # check rooms are valid room IDs | ||||
|         room_id_keys = ["rooms", "not_rooms"] | ||||
|         for key in room_id_keys: | ||||
|             if key in definition: | ||||
|                 if type(definition[key]) != list: | ||||
|                     raise SynapseError(400, "Expected %s to be a list." % key) | ||||
|                 for room_id in definition[key]: | ||||
|                     RoomID.from_string(room_id) | ||||
| 
 | ||||
|     def _check_definition(self, definition): | ||||
|         """Check if the provided definition is valid. | ||||
| 
 | ||||
|         This inspects not only the types but also the values to make sure they | ||||
|         make sense. | ||||
| 
 | ||||
|         Args: | ||||
|             definition(dict): The filter definition | ||||
|         Raises: | ||||
|             SynapseError: If there was a problem with this definition. | ||||
|         """ | ||||
|         # NB: Filters are the complete json blobs. "Definitions" are an | ||||
|         # individual top-level key e.g. public_user_data. Filters are made of | ||||
|         # many definitions. | ||||
|         if type(definition) != dict: | ||||
|             raise SynapseError( | ||||
|                 400, "Expected JSON object, not %s" % (definition,) | ||||
|             ) | ||||
| 
 | ||||
|         self._check_definition_room_lists(definition) | ||||
| 
 | ||||
|         # check senders are valid user IDs | ||||
|         user_id_keys = ["senders", "not_senders"] | ||||
|         for key in user_id_keys: | ||||
|             if key in definition: | ||||
|                 if type(definition[key]) != list: | ||||
|                     raise SynapseError(400, "Expected %s to be a list." % key) | ||||
|                 for user_id in definition[key]: | ||||
|                     UserID.from_string(user_id) | ||||
| 
 | ||||
|         # TODO: We don't limit event type values but we probably should... | ||||
|         # check types are valid event types | ||||
|         event_keys = ["types", "not_types"] | ||||
|         for key in event_keys: | ||||
|             if key in definition: | ||||
|                 if type(definition[key]) != list: | ||||
|                     raise SynapseError(400, "Expected %s to be a list." % key) | ||||
|                 for event_type in definition[key]: | ||||
|                     if not isinstance(event_type, basestring): | ||||
|                         raise SynapseError(400, "Event type should be a string") | ||||
|         try: | ||||
|             jsonschema.validate(user_filter_json, USER_FILTER_SCHEMA, | ||||
|                                 format_checker=FormatChecker()) | ||||
|         except jsonschema.ValidationError as e: | ||||
|             raise SynapseError(400, e.message) | ||||
| 
 | ||||
| 
 | ||||
| class FilterCollection(object): | ||||
| @ -253,19 +329,35 @@ class Filter(object): | ||||
|         Returns: | ||||
|             bool: True if the event matches | ||||
|         """ | ||||
|         sender = event.get("sender", None) | ||||
|         if not sender: | ||||
|             # Presence events have their 'sender' in content.user_id | ||||
|             content = event.get("content") | ||||
|             # account_data has been allowed to have non-dict content, so check type first | ||||
|             if isinstance(content, dict): | ||||
|                 sender = content.get("user_id") | ||||
|         # We usually get the full "events" as dictionaries coming through, | ||||
|         # except for presence which actually gets passed around as its own | ||||
|         # namedtuple type. | ||||
|         if isinstance(event, UserPresenceState): | ||||
|             sender = event.user_id | ||||
|             room_id = None | ||||
|             ev_type = "m.presence" | ||||
|             is_url = False | ||||
|         else: | ||||
|             sender = event.get("sender", None) | ||||
|             if not sender: | ||||
|                 # Presence events had their 'sender' in content.user_id, but are | ||||
|                 # now handled above. We don't know if anything else uses this | ||||
|                 # form. TODO: Check this and probably remove it. | ||||
|                 content = event.get("content") | ||||
|                 # account_data has been allowed to have non-dict content, so | ||||
|                 # check type first | ||||
|                 if isinstance(content, dict): | ||||
|                     sender = content.get("user_id") | ||||
| 
 | ||||
|             room_id = event.get("room_id", None) | ||||
|             ev_type = event.get("type", None) | ||||
|             is_url = "url" in event.get("content", {}) | ||||
| 
 | ||||
|         return self.check_fields( | ||||
|             event.get("room_id", None), | ||||
|             room_id, | ||||
|             sender, | ||||
|             event.get("type", None), | ||||
|             "url" in event.get("content", {}) | ||||
|             ev_type, | ||||
|             is_url, | ||||
|         ) | ||||
| 
 | ||||
|     def check_fields(self, room_id, sender, event_type, contains_url): | ||||
|  | ||||
| @ -29,7 +29,7 @@ from synapse.replication.slave.storage.registration import SlavedRegistrationSto | ||||
| from synapse.storage.engines import create_engine | ||||
| from synapse.util.async import sleep | ||||
| from synapse.util.httpresourcetree import create_resource_tree | ||||
| from synapse.util.logcontext import LoggingContext | ||||
| from synapse.util.logcontext import LoggingContext, PreserveLoggingContext | ||||
| from synapse.util.manhole import manhole | ||||
| from synapse.util.rlimit import change_resource_limit | ||||
| from synapse.util.versionstring import get_version_string | ||||
| @ -157,7 +157,7 @@ def start(config_options): | ||||
| 
 | ||||
|     assert config.worker_app == "synapse.app.appservice" | ||||
| 
 | ||||
|     setup_logging(config.worker_log_config, config.worker_log_file) | ||||
|     setup_logging(config, use_worker_options=True) | ||||
| 
 | ||||
|     events.USE_FROZEN_DICTS = config.use_frozen_dicts | ||||
| 
 | ||||
| @ -187,7 +187,11 @@ def start(config_options): | ||||
|     ps.start_listening(config.worker_listeners) | ||||
| 
 | ||||
|     def run(): | ||||
|         with LoggingContext("run"): | ||||
|         # make sure that we run the reactor with the sentinel log context, | ||||
|         # otherwise other PreserveLoggingContext instances will get confused | ||||
|         # and complain when they see the logcontext arbitrarily swapping | ||||
|         # between the sentinel and `run` logcontexts. | ||||
|         with PreserveLoggingContext(): | ||||
|             logger.info("Running") | ||||
|             change_resource_limit(config.soft_file_limit) | ||||
|             if config.gc_thresholds: | ||||
|  | ||||
| @ -29,13 +29,14 @@ from synapse.replication.slave.storage.keys import SlavedKeyStore | ||||
| from synapse.replication.slave.storage.room import RoomStore | ||||
| from synapse.replication.slave.storage.directory import DirectoryStore | ||||
| from synapse.replication.slave.storage.registration import SlavedRegistrationStore | ||||
| from synapse.replication.slave.storage.transactions import TransactionStore | ||||
| from synapse.rest.client.v1.room import PublicRoomListRestServlet | ||||
| from synapse.server import HomeServer | ||||
| from synapse.storage.client_ips import ClientIpStore | ||||
| from synapse.storage.engines import create_engine | ||||
| from synapse.util.async import sleep | ||||
| from synapse.util.httpresourcetree import create_resource_tree | ||||
| from synapse.util.logcontext import LoggingContext | ||||
| from synapse.util.logcontext import LoggingContext, PreserveLoggingContext | ||||
| from synapse.util.manhole import manhole | ||||
| from synapse.util.rlimit import change_resource_limit | ||||
| from synapse.util.versionstring import get_version_string | ||||
| @ -63,6 +64,7 @@ class ClientReaderSlavedStore( | ||||
|     DirectoryStore, | ||||
|     SlavedApplicationServiceStore, | ||||
|     SlavedRegistrationStore, | ||||
|     TransactionStore, | ||||
|     BaseSlavedStore, | ||||
|     ClientIpStore,  # After BaseSlavedStore because the constructor is different | ||||
| ): | ||||
| @ -171,7 +173,7 @@ def start(config_options): | ||||
| 
 | ||||
|     assert config.worker_app == "synapse.app.client_reader" | ||||
| 
 | ||||
|     setup_logging(config.worker_log_config, config.worker_log_file) | ||||
|     setup_logging(config, use_worker_options=True) | ||||
| 
 | ||||
|     events.USE_FROZEN_DICTS = config.use_frozen_dicts | ||||
| 
 | ||||
| @ -193,7 +195,11 @@ def start(config_options): | ||||
|     ss.start_listening(config.worker_listeners) | ||||
| 
 | ||||
|     def run(): | ||||
|         with LoggingContext("run"): | ||||
|         # make sure that we run the reactor with the sentinel log context, | ||||
|         # otherwise other PreserveLoggingContext instances will get confused | ||||
|         # and complain when they see the logcontext arbitrarily swapping | ||||
|         # between the sentinel and `run` logcontexts. | ||||
|         with PreserveLoggingContext(): | ||||
|             logger.info("Running") | ||||
|             change_resource_limit(config.soft_file_limit) | ||||
|             if config.gc_thresholds: | ||||
|  | ||||
| @ -31,7 +31,7 @@ from synapse.server import HomeServer | ||||
| from synapse.storage.engines import create_engine | ||||
| from synapse.util.async import sleep | ||||
| from synapse.util.httpresourcetree import create_resource_tree | ||||
| from synapse.util.logcontext import LoggingContext | ||||
| from synapse.util.logcontext import LoggingContext, PreserveLoggingContext | ||||
| from synapse.util.manhole import manhole | ||||
| from synapse.util.rlimit import change_resource_limit | ||||
| from synapse.util.versionstring import get_version_string | ||||
| @ -162,7 +162,7 @@ def start(config_options): | ||||
| 
 | ||||
|     assert config.worker_app == "synapse.app.federation_reader" | ||||
| 
 | ||||
|     setup_logging(config.worker_log_config, config.worker_log_file) | ||||
|     setup_logging(config, use_worker_options=True) | ||||
| 
 | ||||
|     events.USE_FROZEN_DICTS = config.use_frozen_dicts | ||||
| 
 | ||||
| @ -184,7 +184,11 @@ def start(config_options): | ||||
|     ss.start_listening(config.worker_listeners) | ||||
| 
 | ||||
|     def run(): | ||||
|         with LoggingContext("run"): | ||||
|         # make sure that we run the reactor with the sentinel log context, | ||||
|         # otherwise other PreserveLoggingContext instances will get confused | ||||
|         # and complain when they see the logcontext arbitrarily swapping | ||||
|         # between the sentinel and `run` logcontexts. | ||||
|         with PreserveLoggingContext(): | ||||
|             logger.info("Running") | ||||
|             change_resource_limit(config.soft_file_limit) | ||||
|             if config.gc_thresholds: | ||||
|  | ||||
| @ -35,7 +35,7 @@ from synapse.storage.engines import create_engine | ||||
| from synapse.storage.presence import UserPresenceState | ||||
| from synapse.util.async import sleep | ||||
| from synapse.util.httpresourcetree import create_resource_tree | ||||
| from synapse.util.logcontext import LoggingContext | ||||
| from synapse.util.logcontext import LoggingContext, PreserveLoggingContext | ||||
| from synapse.util.manhole import manhole | ||||
| from synapse.util.rlimit import change_resource_limit | ||||
| from synapse.util.versionstring import get_version_string | ||||
| @ -160,7 +160,7 @@ def start(config_options): | ||||
| 
 | ||||
|     assert config.worker_app == "synapse.app.federation_sender" | ||||
| 
 | ||||
|     setup_logging(config.worker_log_config, config.worker_log_file) | ||||
|     setup_logging(config, use_worker_options=True) | ||||
| 
 | ||||
|     events.USE_FROZEN_DICTS = config.use_frozen_dicts | ||||
| 
 | ||||
| @ -193,7 +193,11 @@ def start(config_options): | ||||
|     ps.start_listening(config.worker_listeners) | ||||
| 
 | ||||
|     def run(): | ||||
|         with LoggingContext("run"): | ||||
|         # make sure that we run the reactor with the sentinel log context, | ||||
|         # otherwise other PreserveLoggingContext instances will get confused | ||||
|         # and complain when they see the logcontext arbitrarily swapping | ||||
|         # between the sentinel and `run` logcontexts. | ||||
|         with PreserveLoggingContext(): | ||||
|             logger.info("Running") | ||||
|             change_resource_limit(config.soft_file_limit) | ||||
|             if config.gc_thresholds: | ||||
|  | ||||
| @ -20,6 +20,8 @@ import gc | ||||
| import logging | ||||
| import os | ||||
| import sys | ||||
| 
 | ||||
| import synapse.config.logger | ||||
| from synapse.config._base import ConfigError | ||||
| 
 | ||||
| from synapse.python_dependencies import ( | ||||
| @ -50,7 +52,7 @@ from synapse.api.urls import ( | ||||
| ) | ||||
| from synapse.config.homeserver import HomeServerConfig | ||||
| from synapse.crypto import context_factory | ||||
| from synapse.util.logcontext import LoggingContext | ||||
| from synapse.util.logcontext import LoggingContext, PreserveLoggingContext | ||||
| from synapse.metrics import register_memory_metrics, get_metrics_for | ||||
| from synapse.metrics.resource import MetricsResource, METRICS_PREFIX | ||||
| from synapse.replication.resource import ReplicationResource, REPLICATION_PREFIX | ||||
| @ -286,7 +288,7 @@ def setup(config_options): | ||||
|         # generating config files and shouldn't try to continue. | ||||
|         sys.exit(0) | ||||
| 
 | ||||
|     config.setup_logging() | ||||
|     synapse.config.logger.setup_logging(config, use_worker_options=False) | ||||
| 
 | ||||
|     # check any extra requirements we have now we have a config | ||||
|     check_requirements(config) | ||||
| @ -454,7 +456,12 @@ def run(hs): | ||||
|     def in_thread(): | ||||
|         # Uncomment to enable tracing of log context changes. | ||||
|         # sys.settrace(logcontext_tracer) | ||||
|         with LoggingContext("run"): | ||||
| 
 | ||||
|         # make sure that we run the reactor with the sentinel log context, | ||||
|         # otherwise other PreserveLoggingContext instances will get confused | ||||
|         # and complain when they see the logcontext arbitrarily swapping | ||||
|         # between the sentinel and `run` logcontexts. | ||||
|         with PreserveLoggingContext(): | ||||
|             change_resource_limit(hs.config.soft_file_limit) | ||||
|             if hs.config.gc_thresholds: | ||||
|                 gc.set_threshold(*hs.config.gc_thresholds) | ||||
|  | ||||
| @ -24,6 +24,7 @@ from synapse.metrics.resource import MetricsResource, METRICS_PREFIX | ||||
| from synapse.replication.slave.storage._base import BaseSlavedStore | ||||
| from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore | ||||
| from synapse.replication.slave.storage.registration import SlavedRegistrationStore | ||||
| from synapse.replication.slave.storage.transactions import TransactionStore | ||||
| from synapse.rest.media.v0.content_repository import ContentRepoResource | ||||
| from synapse.rest.media.v1.media_repository import MediaRepositoryResource | ||||
| from synapse.server import HomeServer | ||||
| @ -32,7 +33,7 @@ from synapse.storage.engines import create_engine | ||||
| from synapse.storage.media_repository import MediaRepositoryStore | ||||
| from synapse.util.async import sleep | ||||
| from synapse.util.httpresourcetree import create_resource_tree | ||||
| from synapse.util.logcontext import LoggingContext | ||||
| from synapse.util.logcontext import LoggingContext, PreserveLoggingContext | ||||
| from synapse.util.manhole import manhole | ||||
| from synapse.util.rlimit import change_resource_limit | ||||
| from synapse.util.versionstring import get_version_string | ||||
| @ -59,6 +60,7 @@ logger = logging.getLogger("synapse.app.media_repository") | ||||
| class MediaRepositorySlavedStore( | ||||
|     SlavedApplicationServiceStore, | ||||
|     SlavedRegistrationStore, | ||||
|     TransactionStore, | ||||
|     BaseSlavedStore, | ||||
|     MediaRepositoryStore, | ||||
|     ClientIpStore, | ||||
| @ -168,7 +170,7 @@ def start(config_options): | ||||
| 
 | ||||
|     assert config.worker_app == "synapse.app.media_repository" | ||||
| 
 | ||||
|     setup_logging(config.worker_log_config, config.worker_log_file) | ||||
|     setup_logging(config, use_worker_options=True) | ||||
| 
 | ||||
|     events.USE_FROZEN_DICTS = config.use_frozen_dicts | ||||
| 
 | ||||
| @ -190,7 +192,11 @@ def start(config_options): | ||||
|     ss.start_listening(config.worker_listeners) | ||||
| 
 | ||||
|     def run(): | ||||
|         with LoggingContext("run"): | ||||
|         # make sure that we run the reactor with the sentinel log context, | ||||
|         # otherwise other PreserveLoggingContext instances will get confused | ||||
|         # and complain when they see the logcontext arbitrarily swapping | ||||
|         # between the sentinel and `run` logcontexts. | ||||
|         with PreserveLoggingContext(): | ||||
|             logger.info("Running") | ||||
|             change_resource_limit(config.soft_file_limit) | ||||
|             if config.gc_thresholds: | ||||
|  | ||||
| @ -31,7 +31,8 @@ from synapse.storage.engines import create_engine | ||||
| from synapse.storage import DataStore | ||||
| from synapse.util.async import sleep | ||||
| from synapse.util.httpresourcetree import create_resource_tree | ||||
| from synapse.util.logcontext import LoggingContext, preserve_fn | ||||
| from synapse.util.logcontext import LoggingContext, preserve_fn, \ | ||||
|     PreserveLoggingContext | ||||
| from synapse.util.manhole import manhole | ||||
| from synapse.util.rlimit import change_resource_limit | ||||
| from synapse.util.versionstring import get_version_string | ||||
| @ -245,7 +246,7 @@ def start(config_options): | ||||
| 
 | ||||
|     assert config.worker_app == "synapse.app.pusher" | ||||
| 
 | ||||
|     setup_logging(config.worker_log_config, config.worker_log_file) | ||||
|     setup_logging(config, use_worker_options=True) | ||||
| 
 | ||||
|     events.USE_FROZEN_DICTS = config.use_frozen_dicts | ||||
| 
 | ||||
| @ -275,7 +276,11 @@ def start(config_options): | ||||
|     ps.start_listening(config.worker_listeners) | ||||
| 
 | ||||
|     def run(): | ||||
|         with LoggingContext("run"): | ||||
|         # make sure that we run the reactor with the sentinel log context, | ||||
|         # otherwise other PreserveLoggingContext instances will get confused | ||||
|         # and complain when they see the logcontext arbitrarily swapping | ||||
|         # between the sentinel and `run` logcontexts. | ||||
|         with PreserveLoggingContext(): | ||||
|             logger.info("Running") | ||||
|             change_resource_limit(config.soft_file_limit) | ||||
|             if config.gc_thresholds: | ||||
|  | ||||
| @ -20,7 +20,6 @@ from synapse.api.constants import EventTypes, PresenceState | ||||
| from synapse.config._base import ConfigError | ||||
| from synapse.config.homeserver import HomeServerConfig | ||||
| from synapse.config.logger import setup_logging | ||||
| from synapse.events import FrozenEvent | ||||
| from synapse.handlers.presence import PresenceHandler | ||||
| from synapse.http.site import SynapseSite | ||||
| from synapse.http.server import JsonResource | ||||
| @ -48,7 +47,8 @@ from synapse.storage.presence import PresenceStore, UserPresenceState | ||||
| from synapse.storage.roommember import RoomMemberStore | ||||
| from synapse.util.async import sleep | ||||
| from synapse.util.httpresourcetree import create_resource_tree | ||||
| from synapse.util.logcontext import LoggingContext, preserve_fn | ||||
| from synapse.util.logcontext import LoggingContext, preserve_fn, \ | ||||
|     PreserveLoggingContext | ||||
| from synapse.util.manhole import manhole | ||||
| from synapse.util.rlimit import change_resource_limit | ||||
| from synapse.util.stringutils import random_string | ||||
| @ -399,8 +399,7 @@ class SynchrotronServer(HomeServer): | ||||
|                 position = row[position_index] | ||||
|                 user_id = row[user_index] | ||||
| 
 | ||||
|                 rooms = yield store.get_rooms_for_user(user_id) | ||||
|                 room_ids = [r.room_id for r in rooms] | ||||
|                 room_ids = yield store.get_rooms_for_user(user_id) | ||||
| 
 | ||||
|                 notifier.on_new_event( | ||||
|                     "device_list_key", position, rooms=room_ids, | ||||
| @ -411,11 +410,16 @@ class SynchrotronServer(HomeServer): | ||||
|             stream = result.get("events") | ||||
|             if stream: | ||||
|                 max_position = stream["position"] | ||||
| 
 | ||||
|                 event_map = yield store.get_events([row[1] for row in stream["rows"]]) | ||||
| 
 | ||||
|                 for row in stream["rows"]: | ||||
|                     position = row[0] | ||||
|                     internal = json.loads(row[1]) | ||||
|                     event_json = json.loads(row[2]) | ||||
|                     event = FrozenEvent(event_json, internal_metadata_dict=internal) | ||||
|                     event_id = row[1] | ||||
|                     event = event_map.get(event_id, None) | ||||
|                     if not event: | ||||
|                         continue | ||||
| 
 | ||||
|                     extra_users = () | ||||
|                     if event.type == EventTypes.Member: | ||||
|                         extra_users = (event.state_key,) | ||||
| @ -478,7 +482,7 @@ def start(config_options): | ||||
| 
 | ||||
|     assert config.worker_app == "synapse.app.synchrotron" | ||||
| 
 | ||||
|     setup_logging(config.worker_log_config, config.worker_log_file) | ||||
|     setup_logging(config, use_worker_options=True) | ||||
| 
 | ||||
|     synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts | ||||
| 
 | ||||
| @ -497,7 +501,11 @@ def start(config_options): | ||||
|     ss.start_listening(config.worker_listeners) | ||||
| 
 | ||||
|     def run(): | ||||
|         with LoggingContext("run"): | ||||
|         # make sure that we run the reactor with the sentinel log context, | ||||
|         # otherwise other PreserveLoggingContext instances will get confused | ||||
|         # and complain when they see the logcontext arbitrarily swapping | ||||
|         # between the sentinel and `run` logcontexts. | ||||
|         with PreserveLoggingContext(): | ||||
|             logger.info("Running") | ||||
|             change_resource_limit(config.soft_file_limit) | ||||
|             if config.gc_thresholds: | ||||
|  | ||||
| @ -23,14 +23,27 @@ import signal | ||||
| import subprocess | ||||
| import sys | ||||
| import yaml | ||||
| import errno | ||||
| import time | ||||
| 
 | ||||
| SYNAPSE = [sys.executable, "-B", "-m", "synapse.app.homeserver"] | ||||
| 
 | ||||
| GREEN = "\x1b[1;32m" | ||||
| YELLOW = "\x1b[1;33m" | ||||
| RED = "\x1b[1;31m" | ||||
| NORMAL = "\x1b[m" | ||||
| 
 | ||||
| 
 | ||||
| def pid_running(pid): | ||||
|     try: | ||||
|         os.kill(pid, 0) | ||||
|         return True | ||||
|     except OSError, err: | ||||
|         if err.errno == errno.EPERM: | ||||
|             return True | ||||
|         return False | ||||
| 
 | ||||
| 
 | ||||
| def write(message, colour=NORMAL, stream=sys.stdout): | ||||
|     if colour == NORMAL: | ||||
|         stream.write(message + "\n") | ||||
| @ -38,6 +51,11 @@ def write(message, colour=NORMAL, stream=sys.stdout): | ||||
|         stream.write(colour + message + NORMAL + "\n") | ||||
| 
 | ||||
| 
 | ||||
| def abort(message, colour=RED, stream=sys.stderr): | ||||
|     write(message, colour, stream) | ||||
|     sys.exit(1) | ||||
| 
 | ||||
| 
 | ||||
| def start(configfile): | ||||
|     write("Starting ...") | ||||
|     args = SYNAPSE | ||||
| @ -45,7 +63,8 @@ def start(configfile): | ||||
| 
 | ||||
|     try: | ||||
|         subprocess.check_call(args) | ||||
|         write("started synapse.app.homeserver(%r)" % (configfile,), colour=GREEN) | ||||
|         write("started synapse.app.homeserver(%r)" % | ||||
|               (configfile,), colour=GREEN) | ||||
|     except subprocess.CalledProcessError as e: | ||||
|         write( | ||||
|             "error starting (exit code: %d); see above for logs" % e.returncode, | ||||
| @ -76,8 +95,16 @@ def start_worker(app, configfile, worker_configfile): | ||||
| def stop(pidfile, app): | ||||
|     if os.path.exists(pidfile): | ||||
|         pid = int(open(pidfile).read()) | ||||
|         os.kill(pid, signal.SIGTERM) | ||||
|         write("stopped %s" % (app,), colour=GREEN) | ||||
|         try: | ||||
|             os.kill(pid, signal.SIGTERM) | ||||
|             write("stopped %s" % (app,), colour=GREEN) | ||||
|         except OSError, err: | ||||
|             if err.errno == errno.ESRCH: | ||||
|                 write("%s not running" % (app,), colour=YELLOW) | ||||
|             elif err.errno == errno.EPERM: | ||||
|                 abort("Cannot stop %s: Operation not permitted" % (app,)) | ||||
|             else: | ||||
|                 abort("Cannot stop %s: Unknown error" % (app,)) | ||||
| 
 | ||||
| 
 | ||||
| Worker = collections.namedtuple("Worker", [ | ||||
| @ -190,7 +217,19 @@ def main(): | ||||
|         if start_stop_synapse: | ||||
|             stop(pidfile, "synapse.app.homeserver") | ||||
| 
 | ||||
|         # TODO: Wait for synapse to actually shutdown before starting it again | ||||
|     # Wait for synapse to actually shutdown before starting it again | ||||
|     if action == "restart": | ||||
|         running_pids = [] | ||||
|         if start_stop_synapse and os.path.exists(pidfile): | ||||
|             running_pids.append(int(open(pidfile).read())) | ||||
|         for worker in workers: | ||||
|             if os.path.exists(worker.pidfile): | ||||
|                 running_pids.append(int(open(worker.pidfile).read())) | ||||
|         if len(running_pids) > 0: | ||||
|             write("Waiting for process to exit before restarting...") | ||||
|             for running_pid in running_pids: | ||||
|                 while pid_running(running_pid): | ||||
|                     time.sleep(0.2) | ||||
| 
 | ||||
|     if action == "start" or action == "restart": | ||||
|         if start_stop_synapse: | ||||
|  | ||||
| @ -45,7 +45,6 @@ handlers: | ||||
|     maxBytes: 104857600 | ||||
|     backupCount: 10 | ||||
|     filters: [context] | ||||
|     level: INFO | ||||
|   console: | ||||
|     class: logging.StreamHandler | ||||
|     formatter: precise | ||||
| @ -56,6 +55,8 @@ loggers: | ||||
|         level: INFO | ||||
| 
 | ||||
|     synapse.storage.SQL: | ||||
|         # beware: increasing this to DEBUG will make synapse log sensitive | ||||
|         # information such as access tokens. | ||||
|         level: INFO | ||||
| 
 | ||||
| root: | ||||
| @ -68,6 +69,7 @@ class LoggingConfig(Config): | ||||
| 
 | ||||
|     def read_config(self, config): | ||||
|         self.verbosity = config.get("verbose", 0) | ||||
|         self.no_redirect_stdio = config.get("no_redirect_stdio", False) | ||||
|         self.log_config = self.abspath(config.get("log_config")) | ||||
|         self.log_file = self.abspath(config.get("log_file")) | ||||
| 
 | ||||
| @ -77,10 +79,10 @@ class LoggingConfig(Config): | ||||
|             os.path.join(config_dir_path, server_name + ".log.config") | ||||
|         ) | ||||
|         return """ | ||||
|         # Logging verbosity level. | ||||
|         # Logging verbosity level. Ignored if log_config is specified. | ||||
|         verbose: 0 | ||||
| 
 | ||||
|         # File to write logging to | ||||
|         # File to write logging to. Ignored if log_config is specified. | ||||
|         log_file: "%(log_file)s" | ||||
| 
 | ||||
|         # A yaml python logging config file | ||||
| @ -90,6 +92,8 @@ class LoggingConfig(Config): | ||||
|     def read_arguments(self, args): | ||||
|         if args.verbose is not None: | ||||
|             self.verbosity = args.verbose | ||||
|         if args.no_redirect_stdio is not None: | ||||
|             self.no_redirect_stdio = args.no_redirect_stdio | ||||
|         if args.log_config is not None: | ||||
|             self.log_config = args.log_config | ||||
|         if args.log_file is not None: | ||||
| @ -99,16 +103,22 @@ class LoggingConfig(Config): | ||||
|         logging_group = parser.add_argument_group("logging") | ||||
|         logging_group.add_argument( | ||||
|             '-v', '--verbose', dest="verbose", action='count', | ||||
|             help="The verbosity level." | ||||
|             help="The verbosity level. Specify multiple times to increase " | ||||
|             "verbosity. (Ignored if --log-config is specified.)" | ||||
|         ) | ||||
|         logging_group.add_argument( | ||||
|             '-f', '--log-file', dest="log_file", | ||||
|             help="File to log to." | ||||
|             help="File to log to. (Ignored if --log-config is specified.)" | ||||
|         ) | ||||
|         logging_group.add_argument( | ||||
|             '--log-config', dest="log_config", default=None, | ||||
|             help="Python logging config file" | ||||
|         ) | ||||
|         logging_group.add_argument( | ||||
|             '-n', '--no-redirect-stdio', | ||||
|             action='store_true', default=None, | ||||
|             help="Do not redirect stdout/stderr to the log" | ||||
|         ) | ||||
| 
 | ||||
|     def generate_files(self, config): | ||||
|         log_config = config.get("log_config") | ||||
| @ -118,11 +128,22 @@ class LoggingConfig(Config): | ||||
|                     DEFAULT_LOG_CONFIG.substitute(log_file=config["log_file"]) | ||||
|                 ) | ||||
| 
 | ||||
|     def setup_logging(self): | ||||
|         setup_logging(self.log_config, self.log_file, self.verbosity) | ||||
| 
 | ||||
| def setup_logging(config, use_worker_options=False): | ||||
|     """ Set up python logging | ||||
| 
 | ||||
|     Args: | ||||
|         config (LoggingConfig | synapse.config.workers.WorkerConfig): | ||||
|             configuration data | ||||
| 
 | ||||
|         use_worker_options (bool): True to use 'worker_log_config' and | ||||
|             'worker_log_file' options instead of 'log_config' and 'log_file'. | ||||
|     """ | ||||
|     log_config = (config.worker_log_config if use_worker_options | ||||
|                   else config.log_config) | ||||
|     log_file = (config.worker_log_file if use_worker_options | ||||
|                 else config.log_file) | ||||
| 
 | ||||
| def setup_logging(log_config=None, log_file=None, verbosity=None): | ||||
|     log_format = ( | ||||
|         "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s" | ||||
|         " - %(message)s" | ||||
| @ -131,9 +152,9 @@ def setup_logging(log_config=None, log_file=None, verbosity=None): | ||||
| 
 | ||||
|         level = logging.INFO | ||||
|         level_for_storage = logging.INFO | ||||
|         if verbosity: | ||||
|         if config.verbosity: | ||||
|             level = logging.DEBUG | ||||
|             if verbosity > 1: | ||||
|             if config.verbosity > 1: | ||||
|                 level_for_storage = logging.DEBUG | ||||
| 
 | ||||
|         # FIXME: we need a logging.WARN for a -q quiet option | ||||
| @ -153,14 +174,6 @@ def setup_logging(log_config=None, log_file=None, verbosity=None): | ||||
|                 logger.info("Closing log file due to SIGHUP") | ||||
|                 handler.doRollover() | ||||
|                 logger.info("Opened new log file due to SIGHUP") | ||||
| 
 | ||||
|             # TODO(paul): obviously this is a terrible mechanism for | ||||
|             #   stealing SIGHUP, because it means no other part of synapse | ||||
|             #   can use it instead. If we want to catch SIGHUP anywhere | ||||
|             #   else as well, I'd suggest we find a nicer way to broadcast | ||||
|             #   it around. | ||||
|             if getattr(signal, "SIGHUP"): | ||||
|                 signal.signal(signal.SIGHUP, sighup) | ||||
|         else: | ||||
|             handler = logging.StreamHandler() | ||||
|         handler.setFormatter(formatter) | ||||
| @ -169,8 +182,25 @@ def setup_logging(log_config=None, log_file=None, verbosity=None): | ||||
| 
 | ||||
|         logger.addHandler(handler) | ||||
|     else: | ||||
|         with open(log_config, 'r') as f: | ||||
|             logging.config.dictConfig(yaml.load(f)) | ||||
|         def load_log_config(): | ||||
|             with open(log_config, 'r') as f: | ||||
|                 logging.config.dictConfig(yaml.load(f)) | ||||
| 
 | ||||
|         def sighup(signum, stack): | ||||
|             # it might be better to use a file watcher or something for this. | ||||
|             logging.info("Reloading log config from %s due to SIGHUP", | ||||
|                          log_config) | ||||
|             load_log_config() | ||||
| 
 | ||||
|         load_log_config() | ||||
| 
 | ||||
|     # TODO(paul): obviously this is a terrible mechanism for | ||||
|     #   stealing SIGHUP, because it means no other part of synapse | ||||
|     #   can use it instead. If we want to catch SIGHUP anywhere | ||||
|     #   else as well, I'd suggest we find a nicer way to broadcast | ||||
|     #   it around. | ||||
|     if getattr(signal, "SIGHUP"): | ||||
|         signal.signal(signal.SIGHUP, sighup) | ||||
| 
 | ||||
|     # It's critical to point twisted's internal logging somewhere, otherwise it | ||||
|     # stacks up and leaks kup to 64K object; | ||||
| @ -183,4 +213,7 @@ def setup_logging(log_config=None, log_file=None, verbosity=None): | ||||
|     # | ||||
|     # However this may not be too much of a problem if we are just writing to a file. | ||||
|     observer = STDLibLogObserver() | ||||
|     globalLogBeginner.beginLoggingTo([observer]) | ||||
|     globalLogBeginner.beginLoggingTo( | ||||
|         [observer], | ||||
|         redirectStandardIO=not config.no_redirect_stdio, | ||||
|     ) | ||||
|  | ||||
| @ -15,7 +15,6 @@ | ||||
| 
 | ||||
| from synapse.crypto.keyclient import fetch_server_key | ||||
| from synapse.api.errors import SynapseError, Codes | ||||
| from synapse.util.retryutils import get_retry_limiter | ||||
| from synapse.util import unwrapFirstError | ||||
| from synapse.util.async import ObservableDeferred | ||||
| from synapse.util.logcontext import ( | ||||
| @ -96,10 +95,11 @@ class Keyring(object): | ||||
|         verify_requests = [] | ||||
| 
 | ||||
|         for server_name, json_object in server_and_json: | ||||
|             logger.debug("Verifying for %s", server_name) | ||||
| 
 | ||||
|             key_ids = signature_ids(json_object, server_name) | ||||
|             if not key_ids: | ||||
|                 logger.warn("Request from %s: no supported signature keys", | ||||
|                             server_name) | ||||
|                 deferred = defer.fail(SynapseError( | ||||
|                     400, | ||||
|                     "Not signed with a supported algorithm", | ||||
| @ -108,6 +108,9 @@ class Keyring(object): | ||||
|             else: | ||||
|                 deferred = defer.Deferred() | ||||
| 
 | ||||
|             logger.debug("Verifying for %s with key_ids %s", | ||||
|                          server_name, key_ids) | ||||
| 
 | ||||
|             verify_request = VerifyKeyRequest( | ||||
|                 server_name, key_ids, json_object, deferred | ||||
|             ) | ||||
| @ -142,6 +145,9 @@ class Keyring(object): | ||||
| 
 | ||||
|             json_object = verify_request.json_object | ||||
| 
 | ||||
|             logger.debug("Got key %s %s:%s for server %s, verifying" % ( | ||||
|                 key_id, verify_key.alg, verify_key.version, server_name, | ||||
|             )) | ||||
|             try: | ||||
|                 verify_signed_json(json_object, server_name, verify_key) | ||||
|             except: | ||||
| @ -231,8 +237,14 @@ class Keyring(object): | ||||
|             d.addBoth(rm, server_name) | ||||
| 
 | ||||
|     def get_server_verify_keys(self, verify_requests): | ||||
|         """Takes a dict of KeyGroups and tries to find at least one key for | ||||
|         each group. | ||||
|         """Tries to find at least one key for each verify request | ||||
| 
 | ||||
|         For each verify_request, verify_request.deferred is called back with | ||||
|         params (server_name, key_id, VerifyKey) if a key is found, or errbacked | ||||
|         with a SynapseError if none of the keys are found. | ||||
| 
 | ||||
|         Args: | ||||
|             verify_requests (list[VerifyKeyRequest]): list of verify requests | ||||
|         """ | ||||
| 
 | ||||
|         # These are functions that produce keys given a list of key ids | ||||
| @ -245,8 +257,11 @@ class Keyring(object): | ||||
|         @defer.inlineCallbacks | ||||
|         def do_iterations(): | ||||
|             with Measure(self.clock, "get_server_verify_keys"): | ||||
|                 # dict[str, dict[str, VerifyKey]]: results so far. | ||||
|                 # map server_name -> key_id -> VerifyKey | ||||
|                 merged_results = {} | ||||
| 
 | ||||
|                 # dict[str, set(str)]: keys to fetch for each server | ||||
|                 missing_keys = {} | ||||
|                 for verify_request in verify_requests: | ||||
|                     missing_keys.setdefault(verify_request.server_name, set()).update( | ||||
| @ -308,6 +323,16 @@ class Keyring(object): | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def get_keys_from_store(self, server_name_and_key_ids): | ||||
|         """ | ||||
| 
 | ||||
|         Args: | ||||
|             server_name_and_key_ids (list[(str, iterable[str])]): | ||||
|                 list of (server_name, iterable[key_id]) tuples to fetch keys for | ||||
| 
 | ||||
|         Returns: | ||||
|             Deferred: resolves to dict[str, dict[str, VerifyKey]]: map from | ||||
|                 server_name -> key_id -> VerifyKey | ||||
|         """ | ||||
|         res = yield preserve_context_over_deferred(defer.gatherResults( | ||||
|             [ | ||||
|                 preserve_fn(self.store.get_server_verify_keys)( | ||||
| @ -356,30 +381,24 @@ class Keyring(object): | ||||
|     def get_keys_from_server(self, server_name_and_key_ids): | ||||
|         @defer.inlineCallbacks | ||||
|         def get_key(server_name, key_ids): | ||||
|             limiter = yield get_retry_limiter( | ||||
|                 server_name, | ||||
|                 self.clock, | ||||
|                 self.store, | ||||
|             ) | ||||
|             with limiter: | ||||
|                 keys = None | ||||
|                 try: | ||||
|                     keys = yield self.get_server_verify_key_v2_direct( | ||||
|                         server_name, key_ids | ||||
|                     ) | ||||
|                 except Exception as e: | ||||
|                     logger.info( | ||||
|                         "Unable to get key %r for %r directly: %s %s", | ||||
|                         key_ids, server_name, | ||||
|                         type(e).__name__, str(e.message), | ||||
|                     ) | ||||
|             keys = None | ||||
|             try: | ||||
|                 keys = yield self.get_server_verify_key_v2_direct( | ||||
|                     server_name, key_ids | ||||
|                 ) | ||||
|             except Exception as e: | ||||
|                 logger.info( | ||||
|                     "Unable to get key %r for %r directly: %s %s", | ||||
|                     key_ids, server_name, | ||||
|                     type(e).__name__, str(e.message), | ||||
|                 ) | ||||
| 
 | ||||
|                 if not keys: | ||||
|                     keys = yield self.get_server_verify_key_v1_direct( | ||||
|                         server_name, key_ids | ||||
|                     ) | ||||
|             if not keys: | ||||
|                 keys = yield self.get_server_verify_key_v1_direct( | ||||
|                     server_name, key_ids | ||||
|                 ) | ||||
| 
 | ||||
|                     keys = {server_name: keys} | ||||
|                 keys = {server_name: keys} | ||||
| 
 | ||||
|             defer.returnValue(keys) | ||||
| 
 | ||||
|  | ||||
| @ -15,6 +15,32 @@ | ||||
| 
 | ||||
| 
 | ||||
| class EventContext(object): | ||||
|     """ | ||||
|     Attributes: | ||||
|         current_state_ids (dict[(str, str), str]): | ||||
|             The current state map including the current event. | ||||
|             (type, state_key) -> event_id | ||||
| 
 | ||||
|         prev_state_ids (dict[(str, str), str]): | ||||
|             The current state map excluding the current event. | ||||
|             (type, state_key) -> event_id | ||||
| 
 | ||||
|         state_group (int): state group id | ||||
|         rejected (bool|str): A rejection reason if the event was rejected, else | ||||
|             False | ||||
| 
 | ||||
|         push_actions (list[(str, list[object])]): list of (user_id, actions) | ||||
|             tuples | ||||
| 
 | ||||
|         prev_group (int): Previously persisted state group. ``None`` for an | ||||
|             outlier. | ||||
|         delta_ids (dict[(str, str), str]): Delta from ``prev_group``. | ||||
|             (type, state_key) -> event_id. ``None`` for an outlier. | ||||
| 
 | ||||
|         prev_state_events (?): XXX: is this ever set to anything other than | ||||
|             the empty list? | ||||
|     """ | ||||
| 
 | ||||
|     __slots__ = [ | ||||
|         "current_state_ids", | ||||
|         "prev_state_ids", | ||||
|  | ||||
| @ -29,7 +29,7 @@ from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred | ||||
| from synapse.events import FrozenEvent, builder | ||||
| import synapse.metrics | ||||
| 
 | ||||
| from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination | ||||
| from synapse.util.retryutils import NotRetryingDestination | ||||
| 
 | ||||
| import copy | ||||
| import itertools | ||||
| @ -88,7 +88,7 @@ class FederationClient(FederationBase): | ||||
| 
 | ||||
|     @log_function | ||||
|     def make_query(self, destination, query_type, args, | ||||
|                    retry_on_dns_fail=False): | ||||
|                    retry_on_dns_fail=False, ignore_backoff=False): | ||||
|         """Sends a federation Query to a remote homeserver of the given type | ||||
|         and arguments. | ||||
| 
 | ||||
| @ -98,6 +98,8 @@ class FederationClient(FederationBase): | ||||
|                 handler name used in register_query_handler(). | ||||
|             args (dict): Mapping of strings to strings containing the details | ||||
|                 of the query request. | ||||
|             ignore_backoff (bool): true to ignore the historical backoff data | ||||
|                 and try the request anyway. | ||||
| 
 | ||||
|         Returns: | ||||
|             a Deferred which will eventually yield a JSON object from the | ||||
| @ -106,7 +108,8 @@ class FederationClient(FederationBase): | ||||
|         sent_queries_counter.inc(query_type) | ||||
| 
 | ||||
|         return self.transport_layer.make_query( | ||||
|             destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail | ||||
|             destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail, | ||||
|             ignore_backoff=ignore_backoff, | ||||
|         ) | ||||
| 
 | ||||
|     @log_function | ||||
| @ -234,31 +237,24 @@ class FederationClient(FederationBase): | ||||
|                 continue | ||||
| 
 | ||||
|             try: | ||||
|                 limiter = yield get_retry_limiter( | ||||
|                     destination, | ||||
|                     self._clock, | ||||
|                     self.store, | ||||
|                 transaction_data = yield self.transport_layer.get_event( | ||||
|                     destination, event_id, timeout=timeout, | ||||
|                 ) | ||||
| 
 | ||||
|                 with limiter: | ||||
|                     transaction_data = yield self.transport_layer.get_event( | ||||
|                         destination, event_id, timeout=timeout, | ||||
|                     ) | ||||
|                 logger.debug("transaction_data %r", transaction_data) | ||||
| 
 | ||||
|                     logger.debug("transaction_data %r", transaction_data) | ||||
|                 pdu_list = [ | ||||
|                     self.event_from_pdu_json(p, outlier=outlier) | ||||
|                     for p in transaction_data["pdus"] | ||||
|                 ] | ||||
| 
 | ||||
|                     pdu_list = [ | ||||
|                         self.event_from_pdu_json(p, outlier=outlier) | ||||
|                         for p in transaction_data["pdus"] | ||||
|                     ] | ||||
|                 if pdu_list and pdu_list[0]: | ||||
|                     pdu = pdu_list[0] | ||||
| 
 | ||||
|                     if pdu_list and pdu_list[0]: | ||||
|                         pdu = pdu_list[0] | ||||
|                     # Check signatures are correct. | ||||
|                     signed_pdu = yield self._check_sigs_and_hashes([pdu])[0] | ||||
| 
 | ||||
|                         # Check signatures are correct. | ||||
|                         signed_pdu = yield self._check_sigs_and_hashes([pdu])[0] | ||||
| 
 | ||||
|                         break | ||||
|                     break | ||||
| 
 | ||||
|                 pdu_attempts[destination] = now | ||||
| 
 | ||||
|  | ||||
| @ -52,7 +52,6 @@ class FederationServer(FederationBase): | ||||
| 
 | ||||
|         self.auth = hs.get_auth() | ||||
| 
 | ||||
|         self._room_pdu_linearizer = Linearizer("fed_room_pdu") | ||||
|         self._server_linearizer = Linearizer("fed_server") | ||||
| 
 | ||||
|         # We cache responses to state queries, as they take a while and often | ||||
| @ -147,11 +146,15 @@ class FederationServer(FederationBase): | ||||
|             # check that it's actually being sent from a valid destination to | ||||
|             # workaround bug #1753 in 0.18.5 and 0.18.6 | ||||
|             if transaction.origin != get_domain_from_id(pdu.event_id): | ||||
|                 # We continue to accept join events from any server; this is | ||||
|                 # necessary for the federation join dance to work correctly. | ||||
|                 # (When we join over federation, the "helper" server is | ||||
|                 # responsible for sending out the join event, rather than the | ||||
|                 # origin. See bug #1893). | ||||
|                 if not ( | ||||
|                     pdu.type == 'm.room.member' and | ||||
|                     pdu.content and | ||||
|                     pdu.content.get("membership", None) == 'join' and | ||||
|                     self.hs.is_mine_id(pdu.state_key) | ||||
|                     pdu.content.get("membership", None) == 'join' | ||||
|                 ): | ||||
|                     logger.info( | ||||
|                         "Discarding PDU %s from invalid origin %s", | ||||
| @ -165,7 +168,7 @@ class FederationServer(FederationBase): | ||||
|                     ) | ||||
| 
 | ||||
|             try: | ||||
|                 yield self._handle_new_pdu(transaction.origin, pdu) | ||||
|                 yield self._handle_received_pdu(transaction.origin, pdu) | ||||
|                 results.append({}) | ||||
|             except FederationError as e: | ||||
|                 self.send_failure(e, transaction.origin) | ||||
| @ -497,27 +500,16 @@ class FederationServer(FederationBase): | ||||
|         ) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     @log_function | ||||
|     def _handle_new_pdu(self, origin, pdu, get_missing=True): | ||||
|     def _handle_received_pdu(self, origin, pdu): | ||||
|         """ Process a PDU received in a federation /send/ transaction. | ||||
| 
 | ||||
|         # We reprocess pdus when we have seen them only as outliers | ||||
|         existing = yield self._get_persisted_pdu( | ||||
|             origin, pdu.event_id, do_auth=False | ||||
|         ) | ||||
| 
 | ||||
|         # FIXME: Currently we fetch an event again when we already have it | ||||
|         # if it has been marked as an outlier. | ||||
| 
 | ||||
|         already_seen = ( | ||||
|             existing and ( | ||||
|                 not existing.internal_metadata.is_outlier() | ||||
|                 or pdu.internal_metadata.is_outlier() | ||||
|             ) | ||||
|         ) | ||||
|         if already_seen: | ||||
|             logger.debug("Already seen pdu %s", pdu.event_id) | ||||
|             return | ||||
|         Args: | ||||
|             origin (str): server which sent the pdu | ||||
|             pdu (FrozenEvent): received pdu | ||||
| 
 | ||||
|         Returns (Deferred): completes with None | ||||
|         Raises: FederationError if the signatures / hash do not match | ||||
|     """ | ||||
|         # Check signature. | ||||
|         try: | ||||
|             pdu = yield self._check_sigs_and_hash(pdu) | ||||
| @ -529,143 +521,7 @@ class FederationServer(FederationBase): | ||||
|                 affected=pdu.event_id, | ||||
|             ) | ||||
| 
 | ||||
|         state = None | ||||
| 
 | ||||
|         auth_chain = [] | ||||
| 
 | ||||
|         have_seen = yield self.store.have_events( | ||||
|             [ev for ev, _ in pdu.prev_events] | ||||
|         ) | ||||
| 
 | ||||
|         fetch_state = False | ||||
| 
 | ||||
|         # Get missing pdus if necessary. | ||||
|         if not pdu.internal_metadata.is_outlier(): | ||||
|             # We only backfill backwards to the min depth. | ||||
|             min_depth = yield self.handler.get_min_depth_for_context( | ||||
|                 pdu.room_id | ||||
|             ) | ||||
| 
 | ||||
|             logger.debug( | ||||
|                 "_handle_new_pdu min_depth for %s: %d", | ||||
|                 pdu.room_id, min_depth | ||||
|             ) | ||||
| 
 | ||||
|             prevs = {e_id for e_id, _ in pdu.prev_events} | ||||
|             seen = set(have_seen.keys()) | ||||
| 
 | ||||
|             if min_depth and pdu.depth < min_depth: | ||||
|                 # This is so that we don't notify the user about this | ||||
|                 # message, to work around the fact that some events will | ||||
|                 # reference really really old events we really don't want to | ||||
|                 # send to the clients. | ||||
|                 pdu.internal_metadata.outlier = True | ||||
|             elif min_depth and pdu.depth > min_depth: | ||||
|                 if get_missing and prevs - seen: | ||||
|                     # If we're missing stuff, ensure we only fetch stuff one | ||||
|                     # at a time. | ||||
|                     logger.info( | ||||
|                         "Acquiring lock for room %r to fetch %d missing events: %r...", | ||||
|                         pdu.room_id, len(prevs - seen), list(prevs - seen)[:5], | ||||
|                     ) | ||||
|                     with (yield self._room_pdu_linearizer.queue(pdu.room_id)): | ||||
|                         logger.info( | ||||
|                             "Acquired lock for room %r to fetch %d missing events", | ||||
|                             pdu.room_id, len(prevs - seen), | ||||
|                         ) | ||||
| 
 | ||||
|                         # We recalculate seen, since it may have changed. | ||||
|                         have_seen = yield self.store.have_events(prevs) | ||||
|                         seen = set(have_seen.keys()) | ||||
| 
 | ||||
|                         if prevs - seen: | ||||
|                             latest = yield self.store.get_latest_event_ids_in_room( | ||||
|                                 pdu.room_id | ||||
|                             ) | ||||
| 
 | ||||
|                             # We add the prev events that we have seen to the latest | ||||
|                             # list to ensure the remote server doesn't give them to us | ||||
|                             latest = set(latest) | ||||
|                             latest |= seen | ||||
| 
 | ||||
|                             logger.info( | ||||
|                                 "Missing %d events for room %r: %r...", | ||||
|                                 len(prevs - seen), pdu.room_id, list(prevs - seen)[:5] | ||||
|                             ) | ||||
| 
 | ||||
|                             # XXX: we set timeout to 10s to help workaround | ||||
|                             # https://github.com/matrix-org/synapse/issues/1733. | ||||
|                             # The reason is to avoid holding the linearizer lock | ||||
|                             # whilst processing inbound /send transactions, causing | ||||
|                             # FDs to stack up and block other inbound transactions | ||||
|                             # which empirically can currently take up to 30 minutes. | ||||
|                             # | ||||
|                             # N.B. this explicitly disables retry attempts. | ||||
|                             # | ||||
|                             # N.B. this also increases our chances of falling back to | ||||
|                             # fetching fresh state for the room if the missing event | ||||
|                             # can't be found, which slightly reduces our security. | ||||
|                             # it may also increase our DAG extremity count for the room, | ||||
|                             # causing additional state resolution?  See #1760. | ||||
|                             # However, fetching state doesn't hold the linearizer lock | ||||
|                             # apparently. | ||||
|                             # | ||||
|                             # see https://github.com/matrix-org/synapse/pull/1744 | ||||
| 
 | ||||
|                             missing_events = yield self.get_missing_events( | ||||
|                                 origin, | ||||
|                                 pdu.room_id, | ||||
|                                 earliest_events_ids=list(latest), | ||||
|                                 latest_events=[pdu], | ||||
|                                 limit=10, | ||||
|                                 min_depth=min_depth, | ||||
|                                 timeout=10000, | ||||
|                             ) | ||||
| 
 | ||||
|                             # We want to sort these by depth so we process them and | ||||
|                             # tell clients about them in order. | ||||
|                             missing_events.sort(key=lambda x: x.depth) | ||||
| 
 | ||||
|                             for e in missing_events: | ||||
|                                 yield self._handle_new_pdu( | ||||
|                                     origin, | ||||
|                                     e, | ||||
|                                     get_missing=False | ||||
|                                 ) | ||||
| 
 | ||||
|                             have_seen = yield self.store.have_events( | ||||
|                                 [ev for ev, _ in pdu.prev_events] | ||||
|                             ) | ||||
| 
 | ||||
|             prevs = {e_id for e_id, _ in pdu.prev_events} | ||||
|             seen = set(have_seen.keys()) | ||||
|             if prevs - seen: | ||||
|                 logger.info( | ||||
|                     "Still missing %d events for room %r: %r...", | ||||
|                     len(prevs - seen), pdu.room_id, list(prevs - seen)[:5] | ||||
|                 ) | ||||
|                 fetch_state = True | ||||
| 
 | ||||
|         if fetch_state: | ||||
|             # We need to get the state at this event, since we haven't | ||||
|             # processed all the prev events. | ||||
|             logger.debug( | ||||
|                 "_handle_new_pdu getting state for %s", | ||||
|                 pdu.room_id | ||||
|             ) | ||||
|             try: | ||||
|                 state, auth_chain = yield self.get_state_for_room( | ||||
|                     origin, pdu.room_id, pdu.event_id, | ||||
|                 ) | ||||
|             except: | ||||
|                 logger.exception("Failed to get state for event: %s", pdu.event_id) | ||||
| 
 | ||||
|         yield self.handler.on_receive_pdu( | ||||
|             origin, | ||||
|             pdu, | ||||
|             state=state, | ||||
|             auth_chain=auth_chain, | ||||
|         ) | ||||
|         yield self.handler.on_receive_pdu(origin, pdu, get_missing=True) | ||||
| 
 | ||||
|     def __str__(self): | ||||
|         return "<ReplicationLayer(%s)>" % self.server_name | ||||
|  | ||||
| @ -54,6 +54,7 @@ class FederationRemoteSendQueue(object): | ||||
|     def __init__(self, hs): | ||||
|         self.server_name = hs.hostname | ||||
|         self.clock = hs.get_clock() | ||||
|         self.notifier = hs.get_notifier() | ||||
| 
 | ||||
|         self.presence_map = {} | ||||
|         self.presence_changed = sorteddict() | ||||
| @ -186,6 +187,8 @@ class FederationRemoteSendQueue(object): | ||||
|         else: | ||||
|             self.edus[pos] = edu | ||||
| 
 | ||||
|         self.notifier.on_new_replication_data() | ||||
| 
 | ||||
|     def send_presence(self, destination, states): | ||||
|         """As per TransactionQueue""" | ||||
|         pos = self._next_pos() | ||||
| @ -199,16 +202,20 @@ class FederationRemoteSendQueue(object): | ||||
|             (destination, state.user_id) for state in states | ||||
|         ] | ||||
| 
 | ||||
|         self.notifier.on_new_replication_data() | ||||
| 
 | ||||
|     def send_failure(self, failure, destination): | ||||
|         """As per TransactionQueue""" | ||||
|         pos = self._next_pos() | ||||
| 
 | ||||
|         self.failures[pos] = (destination, str(failure)) | ||||
|         self.notifier.on_new_replication_data() | ||||
| 
 | ||||
|     def send_device_messages(self, destination): | ||||
|         """As per TransactionQueue""" | ||||
|         pos = self._next_pos() | ||||
|         self.device_messages[pos] = destination | ||||
|         self.notifier.on_new_replication_data() | ||||
| 
 | ||||
|     def get_current_token(self): | ||||
|         return self.pos - 1 | ||||
|  | ||||
| @ -12,7 +12,7 @@ | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| 
 | ||||
| import datetime | ||||
| 
 | ||||
| from twisted.internet import defer | ||||
| 
 | ||||
| @ -22,9 +22,7 @@ from .units import Transaction, Edu | ||||
| from synapse.api.errors import HttpResponseException | ||||
| from synapse.util.async import run_on_reactor | ||||
| from synapse.util.logcontext import preserve_context_over_fn | ||||
| from synapse.util.retryutils import ( | ||||
|     get_retry_limiter, NotRetryingDestination, | ||||
| ) | ||||
| from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter | ||||
| from synapse.util.metrics import measure_func | ||||
| from synapse.types import get_domain_from_id | ||||
| from synapse.handlers.presence import format_user_presence_state | ||||
| @ -99,7 +97,12 @@ class TransactionQueue(object): | ||||
|         # destination -> list of tuple(failure, deferred) | ||||
|         self.pending_failures_by_dest = {} | ||||
| 
 | ||||
|         # destination -> stream_id of last successfully sent to-device message. | ||||
|         # NB: may be a long or an int. | ||||
|         self.last_device_stream_id_by_dest = {} | ||||
| 
 | ||||
|         # destination -> stream_id of last successfully sent device list | ||||
|         # update. | ||||
|         self.last_device_list_stream_id_by_dest = {} | ||||
| 
 | ||||
|         # HACK to get unique tx id | ||||
| @ -300,20 +303,20 @@ class TransactionQueue(object): | ||||
|             ) | ||||
|             return | ||||
| 
 | ||||
|         pending_pdus = [] | ||||
|         try: | ||||
|             self.pending_transactions[destination] = 1 | ||||
| 
 | ||||
|             # This will throw if we wouldn't retry. We do this here so we fail | ||||
|             # quickly, but we will later check this again in the http client, | ||||
|             # hence why we throw the result away. | ||||
|             yield get_retry_limiter(destination, self.clock, self.store) | ||||
| 
 | ||||
|             # XXX: what's this for? | ||||
|             yield run_on_reactor() | ||||
| 
 | ||||
|             pending_pdus = [] | ||||
|             while True: | ||||
|                 limiter = yield get_retry_limiter( | ||||
|                     destination, | ||||
|                     self.clock, | ||||
|                     self.store, | ||||
|                     backoff_on_404=True,  # If we get a 404 the other side has gone | ||||
|                 ) | ||||
| 
 | ||||
|                 device_message_edus, device_stream_id, dev_list_id = ( | ||||
|                     yield self._get_new_device_messages(destination) | ||||
|                 ) | ||||
| @ -369,7 +372,6 @@ class TransactionQueue(object): | ||||
| 
 | ||||
|                 success = yield self._send_new_transaction( | ||||
|                     destination, pending_pdus, pending_edus, pending_failures, | ||||
|                     limiter=limiter, | ||||
|                 ) | ||||
|                 if success: | ||||
|                     # Remove the acknowledged device messages from the database | ||||
| @ -387,12 +389,24 @@ class TransactionQueue(object): | ||||
|                     self.last_device_list_stream_id_by_dest[destination] = dev_list_id | ||||
|                 else: | ||||
|                     break | ||||
|         except NotRetryingDestination: | ||||
|         except NotRetryingDestination as e: | ||||
|             logger.debug( | ||||
|                 "TX [%s] not ready for retry yet - " | ||||
|                 "TX [%s] not ready for retry yet (next retry at %s) - " | ||||
|                 "dropping transaction for now", | ||||
|                 destination, | ||||
|                 datetime.datetime.fromtimestamp( | ||||
|                     (e.retry_last_ts + e.retry_interval) / 1000.0 | ||||
|                 ), | ||||
|             ) | ||||
|         except Exception as e: | ||||
|             logger.warn( | ||||
|                 "TX [%s] Failed to send transaction: %s", | ||||
|                 destination, | ||||
|                 e, | ||||
|             ) | ||||
|             for p, _ in pending_pdus: | ||||
|                 logger.info("Failed to send event %s to %s", p.event_id, | ||||
|                             destination) | ||||
|         finally: | ||||
|             # We want to be *very* sure we delete this after we stop processing | ||||
|             self.pending_transactions.pop(destination, None) | ||||
| @ -432,7 +446,7 @@ class TransactionQueue(object): | ||||
|     @measure_func("_send_new_transaction") | ||||
|     @defer.inlineCallbacks | ||||
|     def _send_new_transaction(self, destination, pending_pdus, pending_edus, | ||||
|                               pending_failures, limiter): | ||||
|                               pending_failures): | ||||
| 
 | ||||
|         # Sort based on the order field | ||||
|         pending_pdus.sort(key=lambda t: t[1]) | ||||
| @ -442,132 +456,104 @@ class TransactionQueue(object): | ||||
| 
 | ||||
|         success = True | ||||
| 
 | ||||
|         logger.debug("TX [%s] _attempt_new_transaction", destination) | ||||
| 
 | ||||
|         txn_id = str(self._next_txn_id) | ||||
| 
 | ||||
|         logger.debug( | ||||
|             "TX [%s] {%s} Attempting new transaction" | ||||
|             " (pdus: %d, edus: %d, failures: %d)", | ||||
|             destination, txn_id, | ||||
|             len(pdus), | ||||
|             len(edus), | ||||
|             len(failures) | ||||
|         ) | ||||
| 
 | ||||
|         logger.debug("TX [%s] Persisting transaction...", destination) | ||||
| 
 | ||||
|         transaction = Transaction.create_new( | ||||
|             origin_server_ts=int(self.clock.time_msec()), | ||||
|             transaction_id=txn_id, | ||||
|             origin=self.server_name, | ||||
|             destination=destination, | ||||
|             pdus=pdus, | ||||
|             edus=edus, | ||||
|             pdu_failures=failures, | ||||
|         ) | ||||
| 
 | ||||
|         self._next_txn_id += 1 | ||||
| 
 | ||||
|         yield self.transaction_actions.prepare_to_send(transaction) | ||||
| 
 | ||||
|         logger.debug("TX [%s] Persisted transaction", destination) | ||||
|         logger.info( | ||||
|             "TX [%s] {%s} Sending transaction [%s]," | ||||
|             " (PDUs: %d, EDUs: %d, failures: %d)", | ||||
|             destination, txn_id, | ||||
|             transaction.transaction_id, | ||||
|             len(pdus), | ||||
|             len(edus), | ||||
|             len(failures), | ||||
|         ) | ||||
| 
 | ||||
|         # Actually send the transaction | ||||
| 
 | ||||
|         # FIXME (erikj): This is a bit of a hack to make the Pdu age | ||||
|         # keys work | ||||
|         def json_data_cb(): | ||||
|             data = transaction.get_dict() | ||||
|             now = int(self.clock.time_msec()) | ||||
|             if "pdus" in data: | ||||
|                 for p in data["pdus"]: | ||||
|                     if "age_ts" in p: | ||||
|                         unsigned = p.setdefault("unsigned", {}) | ||||
|                         unsigned["age"] = now - int(p["age_ts"]) | ||||
|                         del p["age_ts"] | ||||
|             return data | ||||
| 
 | ||||
|         try: | ||||
|             logger.debug("TX [%s] _attempt_new_transaction", destination) | ||||
| 
 | ||||
|             txn_id = str(self._next_txn_id) | ||||
| 
 | ||||
|             logger.debug( | ||||
|                 "TX [%s] {%s} Attempting new transaction" | ||||
|                 " (pdus: %d, edus: %d, failures: %d)", | ||||
|                 destination, txn_id, | ||||
|                 len(pdus), | ||||
|                 len(edus), | ||||
|                 len(failures) | ||||
|             response = yield self.transport_layer.send_transaction( | ||||
|                 transaction, json_data_cb | ||||
|             ) | ||||
|             code = 200 | ||||
| 
 | ||||
|             logger.debug("TX [%s] Persisting transaction...", destination) | ||||
| 
 | ||||
|             transaction = Transaction.create_new( | ||||
|                 origin_server_ts=int(self.clock.time_msec()), | ||||
|                 transaction_id=txn_id, | ||||
|                 origin=self.server_name, | ||||
|                 destination=destination, | ||||
|                 pdus=pdus, | ||||
|                 edus=edus, | ||||
|                 pdu_failures=failures, | ||||
|             ) | ||||
| 
 | ||||
|             self._next_txn_id += 1 | ||||
| 
 | ||||
|             yield self.transaction_actions.prepare_to_send(transaction) | ||||
| 
 | ||||
|             logger.debug("TX [%s] Persisted transaction", destination) | ||||
|             logger.info( | ||||
|                 "TX [%s] {%s} Sending transaction [%s]," | ||||
|                 " (PDUs: %d, EDUs: %d, failures: %d)", | ||||
|                 destination, txn_id, | ||||
|                 transaction.transaction_id, | ||||
|                 len(pdus), | ||||
|                 len(edus), | ||||
|                 len(failures), | ||||
|             ) | ||||
| 
 | ||||
|             with limiter: | ||||
|                 # Actually send the transaction | ||||
| 
 | ||||
|                 # FIXME (erikj): This is a bit of a hack to make the Pdu age | ||||
|                 # keys work | ||||
|                 def json_data_cb(): | ||||
|                     data = transaction.get_dict() | ||||
|                     now = int(self.clock.time_msec()) | ||||
|                     if "pdus" in data: | ||||
|                         for p in data["pdus"]: | ||||
|                             if "age_ts" in p: | ||||
|                                 unsigned = p.setdefault("unsigned", {}) | ||||
|                                 unsigned["age"] = now - int(p["age_ts"]) | ||||
|                                 del p["age_ts"] | ||||
|                     return data | ||||
| 
 | ||||
|                 try: | ||||
|                     response = yield self.transport_layer.send_transaction( | ||||
|                         transaction, json_data_cb | ||||
|                     ) | ||||
|                     code = 200 | ||||
| 
 | ||||
|                     if response: | ||||
|                         for e_id, r in response.get("pdus", {}).items(): | ||||
|                             if "error" in r: | ||||
|                                 logger.warn( | ||||
|                                     "Transaction returned error for %s: %s", | ||||
|                                     e_id, r, | ||||
|                                 ) | ||||
|                 except HttpResponseException as e: | ||||
|                     code = e.code | ||||
|                     response = e.response | ||||
| 
 | ||||
|                     if e.code in (401, 404, 429) or 500 <= e.code: | ||||
|                         logger.info( | ||||
|                             "TX [%s] {%s} got %d response", | ||||
|                             destination, txn_id, code | ||||
|             if response: | ||||
|                 for e_id, r in response.get("pdus", {}).items(): | ||||
|                     if "error" in r: | ||||
|                         logger.warn( | ||||
|                             "Transaction returned error for %s: %s", | ||||
|                             e_id, r, | ||||
|                         ) | ||||
|                         raise e | ||||
|         except HttpResponseException as e: | ||||
|             code = e.code | ||||
|             response = e.response | ||||
| 
 | ||||
|             if e.code in (401, 404, 429) or 500 <= e.code: | ||||
|                 logger.info( | ||||
|                     "TX [%s] {%s} got %d response", | ||||
|                     destination, txn_id, code | ||||
|                 ) | ||||
|                 raise e | ||||
| 
 | ||||
|                 logger.debug("TX [%s] Sent transaction", destination) | ||||
|                 logger.debug("TX [%s] Marking as delivered...", destination) | ||||
|         logger.info( | ||||
|             "TX [%s] {%s} got %d response", | ||||
|             destination, txn_id, code | ||||
|         ) | ||||
| 
 | ||||
|             yield self.transaction_actions.delivered( | ||||
|                 transaction, code, response | ||||
|             ) | ||||
|         logger.debug("TX [%s] Sent transaction", destination) | ||||
|         logger.debug("TX [%s] Marking as delivered...", destination) | ||||
| 
 | ||||
|             logger.debug("TX [%s] Marked as delivered", destination) | ||||
|         yield self.transaction_actions.delivered( | ||||
|             transaction, code, response | ||||
|         ) | ||||
| 
 | ||||
|             if code != 200: | ||||
|                 for p in pdus: | ||||
|                     logger.info( | ||||
|                         "Failed to send event %s to %s", p.event_id, destination | ||||
|                     ) | ||||
|                 success = False | ||||
|         except RuntimeError as e: | ||||
|             # We capture this here as there as nothing actually listens | ||||
|             # for this finishing functions deferred. | ||||
|             logger.warn( | ||||
|                 "TX [%s] Problem in _attempt_transaction: %s", | ||||
|                 destination, | ||||
|                 e, | ||||
|             ) | ||||
| 
 | ||||
|             success = False | ||||
|         logger.debug("TX [%s] Marked as delivered", destination) | ||||
| 
 | ||||
|         if code != 200: | ||||
|             for p in pdus: | ||||
|                 logger.info("Failed to send event %s to %s", p.event_id, destination) | ||||
|         except Exception as e: | ||||
|             # We capture this here as there as nothing actually listens | ||||
|             # for this finishing functions deferred. | ||||
|             logger.warn( | ||||
|                 "TX [%s] Problem in _attempt_transaction: %s", | ||||
|                 destination, | ||||
|                 e, | ||||
|             ) | ||||
| 
 | ||||
|                 logger.info( | ||||
|                     "Failed to send event %s to %s", p.event_id, destination | ||||
|                 ) | ||||
|             success = False | ||||
| 
 | ||||
|             for p in pdus: | ||||
|                 logger.info("Failed to send event %s to %s", p.event_id, destination) | ||||
| 
 | ||||
|         defer.returnValue(success) | ||||
|  | ||||
| @ -163,6 +163,7 @@ class TransportLayerClient(object): | ||||
|             data=json_data, | ||||
|             json_data_callback=json_data_callback, | ||||
|             long_retries=True, | ||||
|             backoff_on_404=True,  # If we get a 404 the other side has gone | ||||
|         ) | ||||
| 
 | ||||
|         logger.debug( | ||||
| @ -174,7 +175,8 @@ class TransportLayerClient(object): | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     @log_function | ||||
|     def make_query(self, destination, query_type, args, retry_on_dns_fail): | ||||
|     def make_query(self, destination, query_type, args, retry_on_dns_fail, | ||||
|                    ignore_backoff=False): | ||||
|         path = PREFIX + "/query/%s" % query_type | ||||
| 
 | ||||
|         content = yield self.client.get_json( | ||||
| @ -183,6 +185,7 @@ class TransportLayerClient(object): | ||||
|             args=args, | ||||
|             retry_on_dns_fail=retry_on_dns_fail, | ||||
|             timeout=10000, | ||||
|             ignore_backoff=ignore_backoff, | ||||
|         ) | ||||
| 
 | ||||
|         defer.returnValue(content) | ||||
| @ -242,6 +245,7 @@ class TransportLayerClient(object): | ||||
|             destination=destination, | ||||
|             path=path, | ||||
|             data=content, | ||||
|             ignore_backoff=True, | ||||
|         ) | ||||
| 
 | ||||
|         defer.returnValue(response) | ||||
| @ -269,6 +273,7 @@ class TransportLayerClient(object): | ||||
|             destination=remote_server, | ||||
|             path=path, | ||||
|             args=args, | ||||
|             ignore_backoff=True, | ||||
|         ) | ||||
| 
 | ||||
|         defer.returnValue(response) | ||||
|  | ||||
| @ -1,5 +1,6 @@ | ||||
| # -*- coding: utf-8 -*- | ||||
| # Copyright 2014 - 2016 OpenMarket Ltd | ||||
| # Copyright 2017 Vector Creations Ltd | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| @ -47,6 +48,7 @@ class AuthHandler(BaseHandler): | ||||
|             LoginType.PASSWORD: self._check_password_auth, | ||||
|             LoginType.RECAPTCHA: self._check_recaptcha, | ||||
|             LoginType.EMAIL_IDENTITY: self._check_email_identity, | ||||
|             LoginType.MSISDN: self._check_msisdn, | ||||
|             LoginType.DUMMY: self._check_dummy_auth, | ||||
|         } | ||||
|         self.bcrypt_rounds = hs.config.bcrypt_rounds | ||||
| @ -307,31 +309,47 @@ class AuthHandler(BaseHandler): | ||||
|                 defer.returnValue(True) | ||||
|         raise LoginError(401, "", errcode=Codes.UNAUTHORIZED) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def _check_email_identity(self, authdict, _): | ||||
|         return self._check_threepid('email', authdict) | ||||
| 
 | ||||
|     def _check_msisdn(self, authdict, _): | ||||
|         return self._check_threepid('msisdn', authdict) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def _check_dummy_auth(self, authdict, _): | ||||
|         yield run_on_reactor() | ||||
|         defer.returnValue(True) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def _check_threepid(self, medium, authdict): | ||||
|         yield run_on_reactor() | ||||
| 
 | ||||
|         if 'threepid_creds' not in authdict: | ||||
|             raise LoginError(400, "Missing threepid_creds", Codes.MISSING_PARAM) | ||||
| 
 | ||||
|         threepid_creds = authdict['threepid_creds'] | ||||
| 
 | ||||
|         identity_handler = self.hs.get_handlers().identity_handler | ||||
| 
 | ||||
|         logger.info("Getting validated threepid. threepidcreds: %r" % (threepid_creds,)) | ||||
|         logger.info("Getting validated threepid. threepidcreds: %r", (threepid_creds,)) | ||||
|         threepid = yield identity_handler.threepid_from_creds(threepid_creds) | ||||
| 
 | ||||
|         if not threepid: | ||||
|             raise LoginError(401, "", errcode=Codes.UNAUTHORIZED) | ||||
| 
 | ||||
|         if threepid['medium'] != medium: | ||||
|             raise LoginError( | ||||
|                 401, | ||||
|                 "Expecting threepid of type '%s', got '%s'" % ( | ||||
|                     medium, threepid['medium'], | ||||
|                 ), | ||||
|                 errcode=Codes.UNAUTHORIZED | ||||
|             ) | ||||
| 
 | ||||
|         threepid['threepid_creds'] = authdict['threepid_creds'] | ||||
| 
 | ||||
|         defer.returnValue(threepid) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def _check_dummy_auth(self, authdict, _): | ||||
|         yield run_on_reactor() | ||||
|         defer.returnValue(True) | ||||
| 
 | ||||
|     def _get_params_recaptcha(self): | ||||
|         return {"public_key": self.hs.config.recaptcha_public_key} | ||||
| 
 | ||||
|  | ||||
| @ -169,6 +169,40 @@ class DeviceHandler(BaseHandler): | ||||
| 
 | ||||
|         yield self.notify_device_update(user_id, [device_id]) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def delete_devices(self, user_id, device_ids): | ||||
|         """ Delete several devices | ||||
| 
 | ||||
|         Args: | ||||
|             user_id (str): | ||||
|             device_ids (str): The list of device IDs to delete | ||||
| 
 | ||||
|         Returns: | ||||
|             defer.Deferred: | ||||
|         """ | ||||
| 
 | ||||
|         try: | ||||
|             yield self.store.delete_devices(user_id, device_ids) | ||||
|         except errors.StoreError, e: | ||||
|             if e.code == 404: | ||||
|                 # no match | ||||
|                 pass | ||||
|             else: | ||||
|                 raise | ||||
| 
 | ||||
|         # Delete access tokens and e2e keys for each device. Not optimised as it is not | ||||
|         # considered as part of a critical path. | ||||
|         for device_id in device_ids: | ||||
|             yield self.store.user_delete_access_tokens( | ||||
|                 user_id, device_id=device_id, | ||||
|                 delete_refresh_tokens=True, | ||||
|             ) | ||||
|             yield self.store.delete_e2e_keys_by_device( | ||||
|                 user_id=user_id, device_id=device_id | ||||
|             ) | ||||
| 
 | ||||
|         yield self.notify_device_update(user_id, device_ids) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def update_device(self, user_id, device_id, content): | ||||
|         """ Update the given device | ||||
| @ -214,8 +248,7 @@ class DeviceHandler(BaseHandler): | ||||
|             user_id, device_ids, list(hosts) | ||||
|         ) | ||||
| 
 | ||||
|         rooms = yield self.store.get_rooms_for_user(user_id) | ||||
|         room_ids = [r.room_id for r in rooms] | ||||
|         room_ids = yield self.store.get_rooms_for_user(user_id) | ||||
| 
 | ||||
|         yield self.notifier.on_new_event( | ||||
|             "device_list_key", position, rooms=room_ids, | ||||
| @ -236,8 +269,7 @@ class DeviceHandler(BaseHandler): | ||||
|             user_id (str) | ||||
|             from_token (StreamToken) | ||||
|         """ | ||||
|         rooms = yield self.store.get_rooms_for_user(user_id) | ||||
|         room_ids = set(r.room_id for r in rooms) | ||||
|         room_ids = yield self.store.get_rooms_for_user(user_id) | ||||
| 
 | ||||
|         # First we check if any devices have changed | ||||
|         changed = yield self.store.get_user_whose_devices_changed( | ||||
| @ -262,7 +294,7 @@ class DeviceHandler(BaseHandler): | ||||
|                 # ordering: treat it the same as a new room | ||||
|                 event_ids = [] | ||||
| 
 | ||||
|             current_state_ids = yield self.state.get_current_state_ids(room_id) | ||||
|             current_state_ids = yield self.store.get_current_state_ids(room_id) | ||||
| 
 | ||||
|             # special-case for an empty prev state: include all members | ||||
|             # in the changed list | ||||
| @ -313,8 +345,8 @@ class DeviceHandler(BaseHandler): | ||||
|     @defer.inlineCallbacks | ||||
|     def user_left_room(self, user, room_id): | ||||
|         user_id = user.to_string() | ||||
|         rooms = yield self.store.get_rooms_for_user(user_id) | ||||
|         if not rooms: | ||||
|         room_ids = yield self.store.get_rooms_for_user(user_id) | ||||
|         if not room_ids: | ||||
|             # We no longer share rooms with this user, so we'll no longer | ||||
|             # receive device updates. Mark this in DB. | ||||
|             yield self.store.mark_remote_user_device_list_as_unsubscribed(user_id) | ||||
| @ -370,8 +402,8 @@ class DeviceListEduUpdater(object): | ||||
|             logger.warning("Got device list update edu for %r from %r", user_id, origin) | ||||
|             return | ||||
| 
 | ||||
|         rooms = yield self.store.get_rooms_for_user(user_id) | ||||
|         if not rooms: | ||||
|         room_ids = yield self.store.get_rooms_for_user(user_id) | ||||
|         if not room_ids: | ||||
|             # We don't share any rooms with this user. Ignore update, as we | ||||
|             # probably won't get any further updates. | ||||
|             return | ||||
|  | ||||
| @ -175,6 +175,7 @@ class DirectoryHandler(BaseHandler): | ||||
|                         "room_alias": room_alias.to_string(), | ||||
|                     }, | ||||
|                     retry_on_dns_fail=False, | ||||
|                     ignore_backoff=True, | ||||
|                 ) | ||||
|             except CodeMessageException as e: | ||||
|                 logging.warn("Error retrieving alias") | ||||
|  | ||||
| @ -22,7 +22,7 @@ from twisted.internet import defer | ||||
| from synapse.api.errors import SynapseError, CodeMessageException | ||||
| from synapse.types import get_domain_from_id | ||||
| from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred | ||||
| from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination | ||||
| from synapse.util.retryutils import NotRetryingDestination | ||||
| 
 | ||||
| logger = logging.getLogger(__name__) | ||||
| 
 | ||||
| @ -121,15 +121,11 @@ class E2eKeysHandler(object): | ||||
|         def do_remote_query(destination): | ||||
|             destination_query = remote_queries_not_in_cache[destination] | ||||
|             try: | ||||
|                 limiter = yield get_retry_limiter( | ||||
|                     destination, self.clock, self.store | ||||
|                 remote_result = yield self.federation.query_client_keys( | ||||
|                     destination, | ||||
|                     {"device_keys": destination_query}, | ||||
|                     timeout=timeout | ||||
|                 ) | ||||
|                 with limiter: | ||||
|                     remote_result = yield self.federation.query_client_keys( | ||||
|                         destination, | ||||
|                         {"device_keys": destination_query}, | ||||
|                         timeout=timeout | ||||
|                     ) | ||||
| 
 | ||||
|                 for user_id, keys in remote_result["device_keys"].items(): | ||||
|                     if user_id in destination_query: | ||||
| @ -239,18 +235,14 @@ class E2eKeysHandler(object): | ||||
|         def claim_client_keys(destination): | ||||
|             device_keys = remote_queries[destination] | ||||
|             try: | ||||
|                 limiter = yield get_retry_limiter( | ||||
|                     destination, self.clock, self.store | ||||
|                 remote_result = yield self.federation.claim_client_keys( | ||||
|                     destination, | ||||
|                     {"one_time_keys": device_keys}, | ||||
|                     timeout=timeout | ||||
|                 ) | ||||
|                 with limiter: | ||||
|                     remote_result = yield self.federation.claim_client_keys( | ||||
|                         destination, | ||||
|                         {"one_time_keys": device_keys}, | ||||
|                         timeout=timeout | ||||
|                     ) | ||||
|                     for user_id, keys in remote_result["one_time_keys"].items(): | ||||
|                         if user_id in device_keys: | ||||
|                             json_result[user_id] = keys | ||||
|                 for user_id, keys in remote_result["one_time_keys"].items(): | ||||
|                     if user_id in device_keys: | ||||
|                         json_result[user_id] = keys | ||||
|             except CodeMessageException as e: | ||||
|                 failures[destination] = { | ||||
|                     "status": e.code, "message": e.message | ||||
| @ -316,7 +308,7 @@ class E2eKeysHandler(object): | ||||
|         # old access_token without an associated device_id. Either way, we | ||||
|         # need to double-check the device is registered to avoid ending up with | ||||
|         # keys without a corresponding device. | ||||
|         self.device_handler.check_device_registered(user_id, device_id) | ||||
|         yield self.device_handler.check_device_registered(user_id, device_id) | ||||
| 
 | ||||
|         result = yield self.store.count_e2e_one_time_keys(user_id, device_id) | ||||
| 
 | ||||
|  | ||||
| @ -14,6 +14,7 @@ | ||||
| # limitations under the License. | ||||
| 
 | ||||
| """Contains handlers for federation events.""" | ||||
| import synapse.util.logcontext | ||||
| from signedjson.key import decode_verify_key_bytes | ||||
| from signedjson.sign import verify_signed_json | ||||
| from unpaddedbase64 import decode_base64 | ||||
| @ -31,7 +32,7 @@ from synapse.util.logcontext import ( | ||||
| ) | ||||
| from synapse.util.metrics import measure_func | ||||
| from synapse.util.logutils import log_function | ||||
| from synapse.util.async import run_on_reactor | ||||
| from synapse.util.async import run_on_reactor, Linearizer | ||||
| from synapse.util.frozenutils import unfreeze | ||||
| from synapse.crypto.event_signing import ( | ||||
|     compute_event_signature, add_hashes_and_signatures, | ||||
| @ -79,29 +80,216 @@ class FederationHandler(BaseHandler): | ||||
| 
 | ||||
|         # When joining a room we need to queue any events for that room up | ||||
|         self.room_queues = {} | ||||
|         self._room_pdu_linearizer = Linearizer("fed_room_pdu") | ||||
| 
 | ||||
|     @log_function | ||||
|     @defer.inlineCallbacks | ||||
|     def on_receive_pdu(self, origin, pdu, state=None, auth_chain=None): | ||||
|         """ Called by the ReplicationLayer when we have a new pdu. We need to | ||||
|         do auth checks and put it through the StateHandler. | ||||
|     @log_function | ||||
|     def on_receive_pdu(self, origin, pdu, get_missing=True): | ||||
|         """ Process a PDU received via a federation /send/ transaction, or | ||||
|         via backfill of missing prev_events | ||||
| 
 | ||||
|         auth_chain and state are None if we already have the necessary state | ||||
|         and prev_events in the db | ||||
|         Args: | ||||
|             origin (str): server which initiated the /send/ transaction. Will | ||||
|                 be used to fetch missing events or state. | ||||
|             pdu (FrozenEvent): received PDU | ||||
|             get_missing (bool): True if we should fetch missing prev_events | ||||
| 
 | ||||
|         Returns (Deferred): completes with None | ||||
|         """ | ||||
|         event = pdu | ||||
| 
 | ||||
|         logger.debug("Got event: %s", event.event_id) | ||||
|         # We reprocess pdus when we have seen them only as outliers | ||||
|         existing = yield self.get_persisted_pdu( | ||||
|             origin, pdu.event_id, do_auth=False | ||||
|         ) | ||||
| 
 | ||||
|         # FIXME: Currently we fetch an event again when we already have it | ||||
|         # if it has been marked as an outlier. | ||||
| 
 | ||||
|         already_seen = ( | ||||
|             existing and ( | ||||
|                 not existing.internal_metadata.is_outlier() | ||||
|                 or pdu.internal_metadata.is_outlier() | ||||
|             ) | ||||
|         ) | ||||
|         if already_seen: | ||||
|             logger.debug("Already seen pdu %s", pdu.event_id) | ||||
|             return | ||||
| 
 | ||||
|         # If we are currently in the process of joining this room, then we | ||||
|         # queue up events for later processing. | ||||
|         if event.room_id in self.room_queues: | ||||
|             self.room_queues[event.room_id].append((pdu, origin)) | ||||
|         if pdu.room_id in self.room_queues: | ||||
|             logger.info("Ignoring PDU %s for room %s from %s for now; join " | ||||
|                         "in progress", pdu.event_id, pdu.room_id, origin) | ||||
|             self.room_queues[pdu.room_id].append((pdu, origin)) | ||||
|             return | ||||
| 
 | ||||
|         logger.debug("Processing event: %s", event.event_id) | ||||
|         state = None | ||||
| 
 | ||||
|         logger.debug("Event: %s", event) | ||||
|         auth_chain = [] | ||||
| 
 | ||||
|         have_seen = yield self.store.have_events( | ||||
|             [ev for ev, _ in pdu.prev_events] | ||||
|         ) | ||||
| 
 | ||||
|         fetch_state = False | ||||
| 
 | ||||
|         # Get missing pdus if necessary. | ||||
|         if not pdu.internal_metadata.is_outlier(): | ||||
|             # We only backfill backwards to the min depth. | ||||
|             min_depth = yield self.get_min_depth_for_context( | ||||
|                 pdu.room_id | ||||
|             ) | ||||
| 
 | ||||
|             logger.debug( | ||||
|                 "_handle_new_pdu min_depth for %s: %d", | ||||
|                 pdu.room_id, min_depth | ||||
|             ) | ||||
| 
 | ||||
|             prevs = {e_id for e_id, _ in pdu.prev_events} | ||||
|             seen = set(have_seen.keys()) | ||||
| 
 | ||||
|             if min_depth and pdu.depth < min_depth: | ||||
|                 # This is so that we don't notify the user about this | ||||
|                 # message, to work around the fact that some events will | ||||
|                 # reference really really old events we really don't want to | ||||
|                 # send to the clients. | ||||
|                 pdu.internal_metadata.outlier = True | ||||
|             elif min_depth and pdu.depth > min_depth: | ||||
|                 if get_missing and prevs - seen: | ||||
|                     # If we're missing stuff, ensure we only fetch stuff one | ||||
|                     # at a time. | ||||
|                     logger.info( | ||||
|                         "Acquiring lock for room %r to fetch %d missing events: %r...", | ||||
|                         pdu.room_id, len(prevs - seen), list(prevs - seen)[:5], | ||||
|                     ) | ||||
|                     with (yield self._room_pdu_linearizer.queue(pdu.room_id)): | ||||
|                         logger.info( | ||||
|                             "Acquired lock for room %r to fetch %d missing events", | ||||
|                             pdu.room_id, len(prevs - seen), | ||||
|                         ) | ||||
| 
 | ||||
|                         yield self._get_missing_events_for_pdu( | ||||
|                             origin, pdu, prevs, min_depth | ||||
|                         ) | ||||
| 
 | ||||
|             prevs = {e_id for e_id, _ in pdu.prev_events} | ||||
|             seen = set(have_seen.keys()) | ||||
|             if prevs - seen: | ||||
|                 logger.info( | ||||
|                     "Still missing %d events for room %r: %r...", | ||||
|                     len(prevs - seen), pdu.room_id, list(prevs - seen)[:5] | ||||
|                 ) | ||||
|                 fetch_state = True | ||||
| 
 | ||||
|         if fetch_state: | ||||
|             # We need to get the state at this event, since we haven't | ||||
|             # processed all the prev events. | ||||
|             logger.debug( | ||||
|                 "_handle_new_pdu getting state for %s", | ||||
|                 pdu.room_id | ||||
|             ) | ||||
|             try: | ||||
|                 state, auth_chain = yield self.replication_layer.get_state_for_room( | ||||
|                     origin, pdu.room_id, pdu.event_id, | ||||
|                 ) | ||||
|             except: | ||||
|                 logger.exception("Failed to get state for event: %s", pdu.event_id) | ||||
| 
 | ||||
|         yield self._process_received_pdu( | ||||
|             origin, | ||||
|             pdu, | ||||
|             state=state, | ||||
|             auth_chain=auth_chain, | ||||
|         ) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth): | ||||
|         """ | ||||
|         Args: | ||||
|             origin (str): Origin of the pdu. Will be called to get the missing events | ||||
|             pdu: received pdu | ||||
|             prevs (str[]): List of event ids which we are missing | ||||
|             min_depth (int): Minimum depth of events to return. | ||||
| 
 | ||||
|         Returns: | ||||
|             Deferred<dict(str, str?)>: updated have_seen dictionary | ||||
|         """ | ||||
|         # We recalculate seen, since it may have changed. | ||||
|         have_seen = yield self.store.have_events(prevs) | ||||
|         seen = set(have_seen.keys()) | ||||
| 
 | ||||
|         if not prevs - seen: | ||||
|             # nothing left to do | ||||
|             defer.returnValue(have_seen) | ||||
| 
 | ||||
|         latest = yield self.store.get_latest_event_ids_in_room( | ||||
|             pdu.room_id | ||||
|         ) | ||||
| 
 | ||||
|         # We add the prev events that we have seen to the latest | ||||
|         # list to ensure the remote server doesn't give them to us | ||||
|         latest = set(latest) | ||||
|         latest |= seen | ||||
| 
 | ||||
|         logger.info( | ||||
|             "Missing %d events for room %r: %r...", | ||||
|             len(prevs - seen), pdu.room_id, list(prevs - seen)[:5] | ||||
|         ) | ||||
| 
 | ||||
|         # XXX: we set timeout to 10s to help workaround | ||||
|         # https://github.com/matrix-org/synapse/issues/1733. | ||||
|         # The reason is to avoid holding the linearizer lock | ||||
|         # whilst processing inbound /send transactions, causing | ||||
|         # FDs to stack up and block other inbound transactions | ||||
|         # which empirically can currently take up to 30 minutes. | ||||
|         # | ||||
|         # N.B. this explicitly disables retry attempts. | ||||
|         # | ||||
|         # N.B. this also increases our chances of falling back to | ||||
|         # fetching fresh state for the room if the missing event | ||||
|         # can't be found, which slightly reduces our security. | ||||
|         # it may also increase our DAG extremity count for the room, | ||||
|         # causing additional state resolution?  See #1760. | ||||
|         # However, fetching state doesn't hold the linearizer lock | ||||
|         # apparently. | ||||
|         # | ||||
|         # see https://github.com/matrix-org/synapse/pull/1744 | ||||
| 
 | ||||
|         missing_events = yield self.replication_layer.get_missing_events( | ||||
|             origin, | ||||
|             pdu.room_id, | ||||
|             earliest_events_ids=list(latest), | ||||
|             latest_events=[pdu], | ||||
|             limit=10, | ||||
|             min_depth=min_depth, | ||||
|             timeout=10000, | ||||
|         ) | ||||
| 
 | ||||
|         # We want to sort these by depth so we process them and | ||||
|         # tell clients about them in order. | ||||
|         missing_events.sort(key=lambda x: x.depth) | ||||
| 
 | ||||
|         for e in missing_events: | ||||
|             yield self.on_receive_pdu( | ||||
|                 origin, | ||||
|                 e, | ||||
|                 get_missing=False | ||||
|             ) | ||||
| 
 | ||||
|         have_seen = yield self.store.have_events( | ||||
|             [ev for ev, _ in pdu.prev_events] | ||||
|         ) | ||||
|         defer.returnValue(have_seen) | ||||
| 
 | ||||
|     @log_function | ||||
|     @defer.inlineCallbacks | ||||
|     def _process_received_pdu(self, origin, pdu, state, auth_chain): | ||||
|         """ Called when we have a new pdu. We need to do auth checks and put it | ||||
|         through the StateHandler. | ||||
|         """ | ||||
|         event = pdu | ||||
| 
 | ||||
|         logger.debug("Processing event: %s", event) | ||||
| 
 | ||||
|         # FIXME (erikj): Awful hack to make the case where we are not currently | ||||
|         # in the room work | ||||
| @ -670,8 +858,6 @@ class FederationHandler(BaseHandler): | ||||
|         """ | ||||
|         logger.debug("Joining %s to %s", joinee, room_id) | ||||
| 
 | ||||
|         yield self.store.clean_room_for_join(room_id) | ||||
| 
 | ||||
|         origin, event = yield self._make_and_verify_event( | ||||
|             target_hosts, | ||||
|             room_id, | ||||
| @ -680,7 +866,15 @@ class FederationHandler(BaseHandler): | ||||
|             content, | ||||
|         ) | ||||
| 
 | ||||
|         # This shouldn't happen, because the RoomMemberHandler has a | ||||
|         # linearizer lock which only allows one operation per user per room | ||||
|         # at a time - so this is just paranoia. | ||||
|         assert (room_id not in self.room_queues) | ||||
| 
 | ||||
|         self.room_queues[room_id] = [] | ||||
| 
 | ||||
|         yield self.store.clean_room_for_join(room_id) | ||||
| 
 | ||||
|         handled_events = set() | ||||
| 
 | ||||
|         try: | ||||
| @ -733,17 +927,36 @@ class FederationHandler(BaseHandler): | ||||
|             room_queue = self.room_queues[room_id] | ||||
|             del self.room_queues[room_id] | ||||
| 
 | ||||
|             for p, origin in room_queue: | ||||
|                 if p.event_id in handled_events: | ||||
|                     continue | ||||
|             # we don't need to wait for the queued events to be processed - | ||||
|             # it's just a best-effort thing at this point. We do want to do | ||||
|             # them roughly in order, though, otherwise we'll end up making | ||||
|             # lots of requests for missing prev_events which we do actually | ||||
|             # have. Hence we fire off the deferred, but don't wait for it. | ||||
| 
 | ||||
|                 try: | ||||
|                     self.on_receive_pdu(origin, p) | ||||
|                 except: | ||||
|                     logger.exception("Couldn't handle pdu") | ||||
|             synapse.util.logcontext.preserve_fn(self._handle_queued_pdus)( | ||||
|                 room_queue | ||||
|             ) | ||||
| 
 | ||||
|         defer.returnValue(True) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def _handle_queued_pdus(self, room_queue): | ||||
|         """Process PDUs which got queued up while we were busy send_joining. | ||||
| 
 | ||||
|         Args: | ||||
|             room_queue (list[FrozenEvent, str]): list of PDUs to be processed | ||||
|                 and the servers that sent them | ||||
|         """ | ||||
|         for p, origin in room_queue: | ||||
|             try: | ||||
|                 logger.info("Processing queued PDU %s which was received " | ||||
|                             "while we were joining %s", p.event_id, p.room_id) | ||||
|                 yield self.on_receive_pdu(origin, p) | ||||
|             except Exception as e: | ||||
|                 logger.warn( | ||||
|                     "Error handling queued PDU %s from %s: %s", | ||||
|                     p.event_id, origin, e) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     @log_function | ||||
|     def on_make_join_request(self, room_id, user_id): | ||||
| @ -791,9 +1004,19 @@ class FederationHandler(BaseHandler): | ||||
|         ) | ||||
| 
 | ||||
|         event.internal_metadata.outlier = False | ||||
|         # Send this event on behalf of the origin server since they may not | ||||
|         # have an up to data view of the state of the room at this event so | ||||
|         # will not know which servers to send the event to. | ||||
|         # Send this event on behalf of the origin server. | ||||
|         # | ||||
|         # The reasons we have the destination server rather than the origin | ||||
|         # server send it are slightly mysterious: the origin server should have | ||||
|         # all the neccessary state once it gets the response to the send_join, | ||||
|         # so it could send the event itself if it wanted to. It may be that | ||||
|         # doing it this way reduces failure modes, or avoids certain attacks | ||||
|         # where a new server selectively tells a subset of the federation that | ||||
|         # it has joined. | ||||
|         # | ||||
|         # The fact is that, as of the current writing, Synapse doesn't send out | ||||
|         # the join event over federation after joining, and changing it now | ||||
|         # would introduce the danger of backwards-compatibility problems. | ||||
|         event.internal_metadata.send_on_behalf_of = origin | ||||
| 
 | ||||
|         context, event_stream_id, max_stream_id = yield self._handle_new_event( | ||||
| @ -878,15 +1101,15 @@ class FederationHandler(BaseHandler): | ||||
|                 user_id, | ||||
|                 "leave" | ||||
|             ) | ||||
|             signed_event = self._sign_event(event) | ||||
|             event = self._sign_event(event) | ||||
|         except SynapseError: | ||||
|             raise | ||||
|         except CodeMessageException as e: | ||||
|             logger.warn("Failed to reject invite: %s", e) | ||||
|             raise SynapseError(500, "Failed to reject invite") | ||||
| 
 | ||||
|         # Try the host we successfully got a response to /make_join/ | ||||
|         # request first. | ||||
|         # Try the host that we succesfully called /make_leave/ on first for | ||||
|         # the /send_leave/ request. | ||||
|         try: | ||||
|             target_hosts.remove(origin) | ||||
|             target_hosts.insert(0, origin) | ||||
| @ -896,7 +1119,7 @@ class FederationHandler(BaseHandler): | ||||
|         try: | ||||
|             yield self.replication_layer.send_leave( | ||||
|                 target_hosts, | ||||
|                 signed_event | ||||
|                 event | ||||
|             ) | ||||
|         except SynapseError: | ||||
|             raise | ||||
| @ -1325,7 +1548,17 @@ class FederationHandler(BaseHandler): | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def _prep_event(self, origin, event, state=None, auth_events=None): | ||||
|         """ | ||||
| 
 | ||||
|         Args: | ||||
|             origin: | ||||
|             event: | ||||
|             state: | ||||
|             auth_events: | ||||
| 
 | ||||
|         Returns: | ||||
|             Deferred, which resolves to synapse.events.snapshot.EventContext | ||||
|         """ | ||||
|         context = yield self.state_handler.compute_event_context( | ||||
|             event, old_state=state, | ||||
|         ) | ||||
|  | ||||
| @ -1,5 +1,6 @@ | ||||
| # -*- coding: utf-8 -*- | ||||
| # Copyright 2015, 2016 OpenMarket Ltd | ||||
| # Copyright 2017 Vector Creations Ltd | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| @ -150,7 +151,7 @@ class IdentityHandler(BaseHandler): | ||||
|         params.update(kwargs) | ||||
| 
 | ||||
|         try: | ||||
|             data = yield self.http_client.post_urlencoded_get_json( | ||||
|             data = yield self.http_client.post_json_get_json( | ||||
|                 "https://%s%s" % ( | ||||
|                     id_server, | ||||
|                     "/_matrix/identity/api/v1/validate/email/requestToken" | ||||
| @ -161,3 +162,37 @@ class IdentityHandler(BaseHandler): | ||||
|         except CodeMessageException as e: | ||||
|             logger.info("Proxied requestToken failed: %r", e) | ||||
|             raise e | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def requestMsisdnToken( | ||||
|             self, id_server, country, phone_number, | ||||
|             client_secret, send_attempt, **kwargs | ||||
|     ): | ||||
|         yield run_on_reactor() | ||||
| 
 | ||||
|         if not self._should_trust_id_server(id_server): | ||||
|             raise SynapseError( | ||||
|                 400, "Untrusted ID server '%s'" % id_server, | ||||
|                 Codes.SERVER_NOT_TRUSTED | ||||
|             ) | ||||
| 
 | ||||
|         params = { | ||||
|             'country': country, | ||||
|             'phone_number': phone_number, | ||||
|             'client_secret': client_secret, | ||||
|             'send_attempt': send_attempt, | ||||
|         } | ||||
|         params.update(kwargs) | ||||
| 
 | ||||
|         try: | ||||
|             data = yield self.http_client.post_json_get_json( | ||||
|                 "https://%s%s" % ( | ||||
|                     id_server, | ||||
|                     "/_matrix/identity/api/v1/validate/msisdn/requestToken" | ||||
|                 ), | ||||
|                 params | ||||
|             ) | ||||
|             defer.returnValue(data) | ||||
|         except CodeMessageException as e: | ||||
|             logger.info("Proxied requestToken failed: %r", e) | ||||
|             raise e | ||||
|  | ||||
| @ -19,6 +19,7 @@ from synapse.api.constants import EventTypes, Membership | ||||
| from synapse.api.errors import AuthError, Codes | ||||
| from synapse.events.utils import serialize_event | ||||
| from synapse.events.validator import EventValidator | ||||
| from synapse.handlers.presence import format_user_presence_state | ||||
| from synapse.streams.config import PaginationConfig | ||||
| from synapse.types import ( | ||||
|     UserID, StreamToken, | ||||
| @ -225,9 +226,17 @@ class InitialSyncHandler(BaseHandler): | ||||
|                 "content": content, | ||||
|             }) | ||||
| 
 | ||||
|         now = self.clock.time_msec() | ||||
| 
 | ||||
|         ret = { | ||||
|             "rooms": rooms_ret, | ||||
|             "presence": presence, | ||||
|             "presence": [ | ||||
|                 { | ||||
|                     "type": "m.presence", | ||||
|                     "content": format_user_presence_state(event, now), | ||||
|                 } | ||||
|                 for event in presence | ||||
|             ], | ||||
|             "account_data": account_data_events, | ||||
|             "receipts": receipt, | ||||
|             "end": now_token.to_string(), | ||||
|  | ||||
| @ -29,6 +29,7 @@ from synapse.api.errors import SynapseError | ||||
| from synapse.api.constants import PresenceState | ||||
| from synapse.storage.presence import UserPresenceState | ||||
| 
 | ||||
| from synapse.util.caches.descriptors import cachedInlineCallbacks | ||||
| from synapse.util.logcontext import preserve_fn | ||||
| from synapse.util.logutils import log_function | ||||
| from synapse.util.metrics import Measure | ||||
| @ -556,9 +557,9 @@ class PresenceHandler(object): | ||||
|         room_ids_to_states = {} | ||||
|         users_to_states = {} | ||||
|         for state in states: | ||||
|             events = yield self.store.get_rooms_for_user(state.user_id) | ||||
|             for e in events: | ||||
|                 room_ids_to_states.setdefault(e.room_id, []).append(state) | ||||
|             room_ids = yield self.store.get_rooms_for_user(state.user_id) | ||||
|             for room_id in room_ids: | ||||
|                 room_ids_to_states.setdefault(room_id, []).append(state) | ||||
| 
 | ||||
|             plist = yield self.store.get_presence_list_observers_accepted(state.user_id) | ||||
|             for u in plist: | ||||
| @ -574,8 +575,7 @@ class PresenceHandler(object): | ||||
|                 if not local_states: | ||||
|                     continue | ||||
| 
 | ||||
|                 users = yield self.store.get_users_in_room(room_id) | ||||
|                 hosts = set(get_domain_from_id(u) for u in users) | ||||
|                 hosts = yield self.store.get_hosts_in_room(room_id) | ||||
| 
 | ||||
|                 for host in hosts: | ||||
|                     hosts_to_states.setdefault(host, []).extend(local_states) | ||||
| @ -719,9 +719,7 @@ class PresenceHandler(object): | ||||
|                 for state in updates | ||||
|             ]) | ||||
|         else: | ||||
|             defer.returnValue([ | ||||
|                 format_user_presence_state(state, now) for state in updates | ||||
|             ]) | ||||
|             defer.returnValue(updates) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def set_state(self, target_user, state, ignore_status_msg=False): | ||||
| @ -795,6 +793,9 @@ class PresenceHandler(object): | ||||
|             as_event=False, | ||||
|         ) | ||||
| 
 | ||||
|         now = self.clock.time_msec() | ||||
|         results[:] = [format_user_presence_state(r, now) for r in results] | ||||
| 
 | ||||
|         is_accepted = { | ||||
|             row["observed_user_id"]: row["accepted"] for row in presence_list | ||||
|         } | ||||
| @ -847,6 +848,7 @@ class PresenceHandler(object): | ||||
|             ) | ||||
| 
 | ||||
|             state_dict = yield self.get_state(observed_user, as_event=False) | ||||
|             state_dict = format_user_presence_state(state_dict, self.clock.time_msec()) | ||||
| 
 | ||||
|             self.federation.send_edu( | ||||
|                 destination=observer_user.domain, | ||||
| @ -910,11 +912,12 @@ class PresenceHandler(object): | ||||
|     def is_visible(self, observed_user, observer_user): | ||||
|         """Returns whether a user can see another user's presence. | ||||
|         """ | ||||
|         observer_rooms = yield self.store.get_rooms_for_user(observer_user.to_string()) | ||||
|         observed_rooms = yield self.store.get_rooms_for_user(observed_user.to_string()) | ||||
| 
 | ||||
|         observer_room_ids = set(r.room_id for r in observer_rooms) | ||||
|         observed_room_ids = set(r.room_id for r in observed_rooms) | ||||
|         observer_room_ids = yield self.store.get_rooms_for_user( | ||||
|             observer_user.to_string() | ||||
|         ) | ||||
|         observed_room_ids = yield self.store.get_rooms_for_user( | ||||
|             observed_user.to_string() | ||||
|         ) | ||||
| 
 | ||||
|         if observer_room_ids & observed_room_ids: | ||||
|             defer.returnValue(True) | ||||
| @ -979,14 +982,18 @@ def should_notify(old_state, new_state): | ||||
|     return False | ||||
| 
 | ||||
| 
 | ||||
| def format_user_presence_state(state, now): | ||||
| def format_user_presence_state(state, now, include_user_id=True): | ||||
|     """Convert UserPresenceState to a format that can be sent down to clients | ||||
|     and to other servers. | ||||
| 
 | ||||
|     The "user_id" is optional so that this function can be used to format presence | ||||
|     updates for client /sync responses and for federation /send requests. | ||||
|     """ | ||||
|     content = { | ||||
|         "presence": state.state, | ||||
|         "user_id": state.user_id, | ||||
|     } | ||||
|     if include_user_id: | ||||
|         content["user_id"] = state.user_id | ||||
|     if state.last_active_ts: | ||||
|         content["last_active_ago"] = now - state.last_active_ts | ||||
|     if state.status_msg and state.state != PresenceState.OFFLINE: | ||||
| @ -1025,7 +1032,6 @@ class PresenceEventSource(object): | ||||
|         # sending down the rare duplicate is not a concern. | ||||
| 
 | ||||
|         with Measure(self.clock, "presence.get_new_events"): | ||||
|             user_id = user.to_string() | ||||
|             if from_key is not None: | ||||
|                 from_key = int(from_key) | ||||
| 
 | ||||
| @ -1034,18 +1040,7 @@ class PresenceEventSource(object): | ||||
| 
 | ||||
|             max_token = self.store.get_current_presence_token() | ||||
| 
 | ||||
|             plist = yield self.store.get_presence_list_accepted(user.localpart) | ||||
|             users_interested_in = set(row["observed_user_id"] for row in plist) | ||||
|             users_interested_in.add(user_id)  # So that we receive our own presence | ||||
| 
 | ||||
|             users_who_share_room = yield self.store.get_users_who_share_room_with_user( | ||||
|                 user_id | ||||
|             ) | ||||
|             users_interested_in.update(users_who_share_room) | ||||
| 
 | ||||
|             if explicit_room_id: | ||||
|                 user_ids = yield self.store.get_users_in_room(explicit_room_id) | ||||
|                 users_interested_in.update(user_ids) | ||||
|             users_interested_in = yield self._get_interested_in(user, explicit_room_id) | ||||
| 
 | ||||
|             user_ids_changed = set() | ||||
|             changed = None | ||||
| @ -1073,16 +1068,13 @@ class PresenceEventSource(object): | ||||
| 
 | ||||
|             updates = yield presence.current_state_for_users(user_ids_changed) | ||||
| 
 | ||||
|         now = self.clock.time_msec() | ||||
| 
 | ||||
|         defer.returnValue(([ | ||||
|             { | ||||
|                 "type": "m.presence", | ||||
|                 "content": format_user_presence_state(s, now), | ||||
|             } | ||||
|             for s in updates.values() | ||||
|             if include_offline or s.state != PresenceState.OFFLINE | ||||
|         ], max_token)) | ||||
|         if include_offline: | ||||
|             defer.returnValue((updates.values(), max_token)) | ||||
|         else: | ||||
|             defer.returnValue(([ | ||||
|                 s for s in updates.itervalues() | ||||
|                 if s.state != PresenceState.OFFLINE | ||||
|             ], max_token)) | ||||
| 
 | ||||
|     def get_current_key(self): | ||||
|         return self.store.get_current_presence_token() | ||||
| @ -1090,6 +1082,31 @@ class PresenceEventSource(object): | ||||
|     def get_pagination_rows(self, user, pagination_config, key): | ||||
|         return self.get_new_events(user, from_key=None, include_offline=False) | ||||
| 
 | ||||
|     @cachedInlineCallbacks(num_args=2, cache_context=True) | ||||
|     def _get_interested_in(self, user, explicit_room_id, cache_context): | ||||
|         """Returns the set of users that the given user should see presence | ||||
|         updates for | ||||
|         """ | ||||
|         user_id = user.to_string() | ||||
|         plist = yield self.store.get_presence_list_accepted( | ||||
|             user.localpart, on_invalidate=cache_context.invalidate, | ||||
|         ) | ||||
|         users_interested_in = set(row["observed_user_id"] for row in plist) | ||||
|         users_interested_in.add(user_id)  # So that we receive our own presence | ||||
| 
 | ||||
|         users_who_share_room = yield self.store.get_users_who_share_room_with_user( | ||||
|             user_id, on_invalidate=cache_context.invalidate, | ||||
|         ) | ||||
|         users_interested_in.update(users_who_share_room) | ||||
| 
 | ||||
|         if explicit_room_id: | ||||
|             user_ids = yield self.store.get_users_in_room( | ||||
|                 explicit_room_id, on_invalidate=cache_context.invalidate, | ||||
|             ) | ||||
|             users_interested_in.update(user_ids) | ||||
| 
 | ||||
|         defer.returnValue(users_interested_in) | ||||
| 
 | ||||
| 
 | ||||
| def handle_timeouts(user_states, is_mine_fn, syncing_user_ids, now): | ||||
|     """Checks the presence of users that have timed out and updates as | ||||
| @ -1157,7 +1174,10 @@ def handle_timeout(state, is_mine, syncing_user_ids, now): | ||||
|         # If there are have been no sync for a while (and none ongoing), | ||||
|         # set presence to offline | ||||
|         if user_id not in syncing_user_ids: | ||||
|             if now - state.last_user_sync_ts > SYNC_ONLINE_TIMEOUT: | ||||
|             # If the user has done something recently but hasn't synced, | ||||
|             # don't set them as offline. | ||||
|             sync_or_active = max(state.last_user_sync_ts, state.last_active_ts) | ||||
|             if now - sync_or_active > SYNC_ONLINE_TIMEOUT: | ||||
|                 state = state.copy_and_replace( | ||||
|                     state=PresenceState.OFFLINE, | ||||
|                     status_msg=None, | ||||
|  | ||||
| @ -52,7 +52,8 @@ class ProfileHandler(BaseHandler): | ||||
|                     args={ | ||||
|                         "user_id": target_user.to_string(), | ||||
|                         "field": "displayname", | ||||
|                     } | ||||
|                     }, | ||||
|                     ignore_backoff=True, | ||||
|                 ) | ||||
|             except CodeMessageException as e: | ||||
|                 if e.code != 404: | ||||
| @ -99,7 +100,8 @@ class ProfileHandler(BaseHandler): | ||||
|                     args={ | ||||
|                         "user_id": target_user.to_string(), | ||||
|                         "field": "avatar_url", | ||||
|                     } | ||||
|                     }, | ||||
|                     ignore_backoff=True, | ||||
|                 ) | ||||
|             except CodeMessageException as e: | ||||
|                 if e.code != 404: | ||||
| @ -156,11 +158,11 @@ class ProfileHandler(BaseHandler): | ||||
| 
 | ||||
|         self.ratelimit(requester) | ||||
| 
 | ||||
|         joins = yield self.store.get_rooms_for_user( | ||||
|         room_ids = yield self.store.get_rooms_for_user( | ||||
|             user.to_string(), | ||||
|         ) | ||||
| 
 | ||||
|         for j in joins: | ||||
|         for room_id in room_ids: | ||||
|             handler = self.hs.get_handlers().room_member_handler | ||||
|             try: | ||||
|                 # Assume the user isn't a guest because we don't let guests set | ||||
| @ -171,12 +173,12 @@ class ProfileHandler(BaseHandler): | ||||
|                 yield handler.update_membership( | ||||
|                     requester, | ||||
|                     user, | ||||
|                     j.room_id, | ||||
|                     room_id, | ||||
|                     "join",  # We treat a profile update like a join. | ||||
|                     ratelimit=False,  # Try to hide that these events aren't atomic. | ||||
|                 ) | ||||
|             except Exception as e: | ||||
|                 logger.warn( | ||||
|                     "Failed to update join event for room %s - %s", | ||||
|                     j.room_id, str(e.message) | ||||
|                     room_id, str(e.message) | ||||
|                 ) | ||||
|  | ||||
| @ -210,10 +210,9 @@ class ReceiptEventSource(object): | ||||
|         else: | ||||
|             from_key = None | ||||
| 
 | ||||
|         rooms = yield self.store.get_rooms_for_user(user.to_string()) | ||||
|         rooms = [room.room_id for room in rooms] | ||||
|         room_ids = yield self.store.get_rooms_for_user(user.to_string()) | ||||
|         events = yield self.store.get_linearized_receipts_for_rooms( | ||||
|             rooms, | ||||
|             room_ids, | ||||
|             from_key=from_key, | ||||
|             to_key=to_key, | ||||
|         ) | ||||
|  | ||||
| @ -21,6 +21,7 @@ from synapse.api.constants import ( | ||||
|     EventTypes, JoinRules, | ||||
| ) | ||||
| from synapse.util.async import concurrently_execute | ||||
| from synapse.util.caches.descriptors import cachedInlineCallbacks | ||||
| from synapse.util.caches.response_cache import ResponseCache | ||||
| from synapse.types import ThirdPartyInstanceID | ||||
| 
 | ||||
| @ -62,6 +63,10 @@ class RoomListHandler(BaseHandler): | ||||
|                 appservice and network id to use an appservice specific one. | ||||
|                 Setting to None returns all public rooms across all lists. | ||||
|         """ | ||||
|         logger.info( | ||||
|             "Getting public room list: limit=%r, since=%r, search=%r, network=%r", | ||||
|             limit, since_token, bool(search_filter), network_tuple, | ||||
|         ) | ||||
|         if search_filter: | ||||
|             # We explicitly don't bother caching searches or requests for | ||||
|             # appservice specific lists. | ||||
| @ -91,7 +96,6 @@ class RoomListHandler(BaseHandler): | ||||
| 
 | ||||
|         rooms_to_order_value = {} | ||||
|         rooms_to_num_joined = {} | ||||
|         rooms_to_latest_event_ids = {} | ||||
| 
 | ||||
|         newly_visible = [] | ||||
|         newly_unpublished = [] | ||||
| @ -116,19 +120,26 @@ class RoomListHandler(BaseHandler): | ||||
| 
 | ||||
|         @defer.inlineCallbacks | ||||
|         def get_order_for_room(room_id): | ||||
|             latest_event_ids = rooms_to_latest_event_ids.get(room_id, None) | ||||
|             if not latest_event_ids: | ||||
|             # Most of the rooms won't have changed between the since token and | ||||
|             # now (especially if the since token is "now"). So, we can ask what | ||||
|             # the current users are in a room (that will hit a cache) and then | ||||
|             # check if the room has changed since the since token. (We have to | ||||
|             # do it in that order to avoid races). | ||||
|             # If things have changed then fall back to getting the current state | ||||
|             # at the since token. | ||||
|             joined_users = yield self.store.get_users_in_room(room_id) | ||||
|             if self.store.has_room_changed_since(room_id, stream_token): | ||||
|                 latest_event_ids = yield self.store.get_forward_extremeties_for_room( | ||||
|                     room_id, stream_token | ||||
|                 ) | ||||
|                 rooms_to_latest_event_ids[room_id] = latest_event_ids | ||||
| 
 | ||||
|             if not latest_event_ids: | ||||
|                 return | ||||
|                 if not latest_event_ids: | ||||
|                     return | ||||
| 
 | ||||
|                 joined_users = yield self.state_handler.get_current_user_in_room( | ||||
|                     room_id, latest_event_ids, | ||||
|                 ) | ||||
| 
 | ||||
|             joined_users = yield self.state_handler.get_current_user_in_room( | ||||
|                 room_id, latest_event_ids, | ||||
|             ) | ||||
|             num_joined_users = len(joined_users) | ||||
|             rooms_to_num_joined[room_id] = num_joined_users | ||||
| 
 | ||||
| @ -165,19 +176,19 @@ class RoomListHandler(BaseHandler): | ||||
|                 rooms_to_scan = rooms_to_scan[:since_token.current_limit] | ||||
|                 rooms_to_scan.reverse() | ||||
| 
 | ||||
|         # Actually generate the entries. _generate_room_entry will append to | ||||
|         # Actually generate the entries. _append_room_entry_to_chunk will append to | ||||
|         # chunk but will stop if len(chunk) > limit | ||||
|         chunk = [] | ||||
|         if limit and not search_filter: | ||||
|             step = limit + 1 | ||||
|             for i in xrange(0, len(rooms_to_scan), step): | ||||
|                 # We iterate here because the vast majority of cases we'll stop | ||||
|                 # at first iteration, but occaisonally _generate_room_entry | ||||
|                 # at first iteration, but occaisonally _append_room_entry_to_chunk | ||||
|                 # won't append to the chunk and so we need to loop again. | ||||
|                 # We don't want to scan over the entire range either as that | ||||
|                 # would potentially waste a lot of work. | ||||
|                 yield concurrently_execute( | ||||
|                     lambda r: self._generate_room_entry( | ||||
|                     lambda r: self._append_room_entry_to_chunk( | ||||
|                         r, rooms_to_num_joined[r], | ||||
|                         chunk, limit, search_filter | ||||
|                     ), | ||||
| @ -187,7 +198,7 @@ class RoomListHandler(BaseHandler): | ||||
|                     break | ||||
|         else: | ||||
|             yield concurrently_execute( | ||||
|                 lambda r: self._generate_room_entry( | ||||
|                 lambda r: self._append_room_entry_to_chunk( | ||||
|                     r, rooms_to_num_joined[r], | ||||
|                     chunk, limit, search_filter | ||||
|                 ), | ||||
| @ -256,21 +267,35 @@ class RoomListHandler(BaseHandler): | ||||
|         defer.returnValue(results) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def _generate_room_entry(self, room_id, num_joined_users, chunk, limit, | ||||
|                              search_filter): | ||||
|     def _append_room_entry_to_chunk(self, room_id, num_joined_users, chunk, limit, | ||||
|                                     search_filter): | ||||
|         """Generate the entry for a room in the public room list and append it | ||||
|         to the `chunk` if it matches the search filter | ||||
|         """ | ||||
|         if limit and len(chunk) > limit + 1: | ||||
|             # We've already got enough, so lets just drop it. | ||||
|             return | ||||
| 
 | ||||
|         result = yield self._generate_room_entry(room_id, num_joined_users) | ||||
| 
 | ||||
|         if result and _matches_room_entry(result, search_filter): | ||||
|             chunk.append(result) | ||||
| 
 | ||||
|     @cachedInlineCallbacks(num_args=1, cache_context=True) | ||||
|     def _generate_room_entry(self, room_id, num_joined_users, cache_context): | ||||
|         """Returns the entry for a room | ||||
|         """ | ||||
|         result = { | ||||
|             "room_id": room_id, | ||||
|             "num_joined_members": num_joined_users, | ||||
|         } | ||||
| 
 | ||||
|         current_state_ids = yield self.state_handler.get_current_state_ids(room_id) | ||||
|         current_state_ids = yield self.store.get_current_state_ids( | ||||
|             room_id, on_invalidate=cache_context.invalidate, | ||||
|         ) | ||||
| 
 | ||||
|         event_map = yield self.store.get_events([ | ||||
|             event_id for key, event_id in current_state_ids.items() | ||||
|             event_id for key, event_id in current_state_ids.iteritems() | ||||
|             if key[0] in ( | ||||
|                 EventTypes.JoinRules, | ||||
|                 EventTypes.Name, | ||||
| @ -294,7 +319,9 @@ class RoomListHandler(BaseHandler): | ||||
|             if join_rule and join_rule != JoinRules.PUBLIC: | ||||
|                 defer.returnValue(None) | ||||
| 
 | ||||
|         aliases = yield self.store.get_aliases_for_room(room_id) | ||||
|         aliases = yield self.store.get_aliases_for_room( | ||||
|             room_id, on_invalidate=cache_context.invalidate | ||||
|         ) | ||||
|         if aliases: | ||||
|             result["aliases"] = aliases | ||||
| 
 | ||||
| @ -334,8 +361,7 @@ class RoomListHandler(BaseHandler): | ||||
|             if avatar_url: | ||||
|                 result["avatar_url"] = avatar_url | ||||
| 
 | ||||
|         if _matches_room_entry(result, search_filter): | ||||
|             chunk.append(result) | ||||
|         defer.returnValue(result) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def get_remote_public_room_list(self, server_name, limit=None, since_token=None, | ||||
|  | ||||
| @ -20,6 +20,7 @@ from synapse.util.metrics import Measure, measure_func | ||||
| from synapse.util.caches.response_cache import ResponseCache | ||||
| from synapse.push.clientformat import format_push_rules_for_user | ||||
| from synapse.visibility import filter_events_for_client | ||||
| from synapse.types import RoomStreamToken | ||||
| 
 | ||||
| from twisted.internet import defer | ||||
| 
 | ||||
| @ -225,8 +226,7 @@ class SyncHandler(object): | ||||
|         with Measure(self.clock, "ephemeral_by_room"): | ||||
|             typing_key = since_token.typing_key if since_token else "0" | ||||
| 
 | ||||
|             rooms = yield self.store.get_rooms_for_user(sync_config.user.to_string()) | ||||
|             room_ids = [room.room_id for room in rooms] | ||||
|             room_ids = yield self.store.get_rooms_for_user(sync_config.user.to_string()) | ||||
| 
 | ||||
|             typing_source = self.event_sources.sources["typing"] | ||||
|             typing, typing_key = yield typing_source.get_new_events( | ||||
| @ -568,16 +568,15 @@ class SyncHandler(object): | ||||
|         since_token = sync_result_builder.since_token | ||||
| 
 | ||||
|         if since_token and since_token.device_list_key: | ||||
|             rooms = yield self.store.get_rooms_for_user(user_id) | ||||
|             room_ids = set(r.room_id for r in rooms) | ||||
|             room_ids = yield self.store.get_rooms_for_user(user_id) | ||||
| 
 | ||||
|             user_ids_changed = set() | ||||
|             changed = yield self.store.get_user_whose_devices_changed( | ||||
|                 since_token.device_list_key | ||||
|             ) | ||||
|             for other_user_id in changed: | ||||
|                 other_rooms = yield self.store.get_rooms_for_user(other_user_id) | ||||
|                 if room_ids.intersection(e.room_id for e in other_rooms): | ||||
|                 other_room_ids = yield self.store.get_rooms_for_user(other_user_id) | ||||
|                 if room_ids.intersection(other_room_ids): | ||||
|                     user_ids_changed.add(other_user_id) | ||||
| 
 | ||||
|             defer.returnValue(user_ids_changed) | ||||
| @ -721,14 +720,14 @@ class SyncHandler(object): | ||||
|             extra_users_ids.update(users) | ||||
|         extra_users_ids.discard(user.to_string()) | ||||
| 
 | ||||
|         states = yield self.presence_handler.get_states( | ||||
|             extra_users_ids, | ||||
|             as_event=True, | ||||
|         ) | ||||
|         presence.extend(states) | ||||
|         if extra_users_ids: | ||||
|             states = yield self.presence_handler.get_states( | ||||
|                 extra_users_ids, | ||||
|             ) | ||||
|             presence.extend(states) | ||||
| 
 | ||||
|         # Deduplicate the presence entries so that there's at most one per user | ||||
|         presence = {p["content"]["user_id"]: p for p in presence}.values() | ||||
|             # Deduplicate the presence entries so that there's at most one per user | ||||
|             presence = {p.user_id: p for p in presence}.values() | ||||
| 
 | ||||
|         presence = sync_config.filter_collection.filter_presence( | ||||
|             presence | ||||
| @ -765,6 +764,21 @@ class SyncHandler(object): | ||||
|             ) | ||||
|             sync_result_builder.now_token = now_token | ||||
| 
 | ||||
|         # We check up front if anything has changed, if it hasn't then there is | ||||
|         # no point in going futher. | ||||
|         since_token = sync_result_builder.since_token | ||||
|         if not sync_result_builder.full_state: | ||||
|             if since_token and not ephemeral_by_room and not account_data_by_room: | ||||
|                 have_changed = yield self._have_rooms_changed(sync_result_builder) | ||||
|                 if not have_changed: | ||||
|                     tags_by_room = yield self.store.get_updated_tags( | ||||
|                         user_id, | ||||
|                         since_token.account_data_key, | ||||
|                     ) | ||||
|                     if not tags_by_room: | ||||
|                         logger.debug("no-oping sync") | ||||
|                         defer.returnValue(([], [])) | ||||
| 
 | ||||
|         ignored_account_data = yield self.store.get_global_account_data_by_type_for_user( | ||||
|             "m.ignored_user_list", user_id=user_id, | ||||
|         ) | ||||
| @ -774,13 +788,12 @@ class SyncHandler(object): | ||||
|         else: | ||||
|             ignored_users = frozenset() | ||||
| 
 | ||||
|         if sync_result_builder.since_token: | ||||
|         if since_token: | ||||
|             res = yield self._get_rooms_changed(sync_result_builder, ignored_users) | ||||
|             room_entries, invited, newly_joined_rooms = res | ||||
| 
 | ||||
|             tags_by_room = yield self.store.get_updated_tags( | ||||
|                 user_id, | ||||
|                 sync_result_builder.since_token.account_data_key, | ||||
|                 user_id, since_token.account_data_key, | ||||
|             ) | ||||
|         else: | ||||
|             res = yield self._get_all_rooms(sync_result_builder, ignored_users) | ||||
| @ -805,7 +818,7 @@ class SyncHandler(object): | ||||
| 
 | ||||
|         # Now we want to get any newly joined users | ||||
|         newly_joined_users = set() | ||||
|         if sync_result_builder.since_token: | ||||
|         if since_token: | ||||
|             for joined_sync in sync_result_builder.joined: | ||||
|                 it = itertools.chain( | ||||
|                     joined_sync.timeline.events, joined_sync.state.values() | ||||
| @ -817,6 +830,38 @@ class SyncHandler(object): | ||||
| 
 | ||||
|         defer.returnValue((newly_joined_rooms, newly_joined_users)) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def _have_rooms_changed(self, sync_result_builder): | ||||
|         """Returns whether there may be any new events that should be sent down | ||||
|         the sync. Returns True if there are. | ||||
|         """ | ||||
|         user_id = sync_result_builder.sync_config.user.to_string() | ||||
|         since_token = sync_result_builder.since_token | ||||
|         now_token = sync_result_builder.now_token | ||||
| 
 | ||||
|         assert since_token | ||||
| 
 | ||||
|         # Get a list of membership change events that have happened. | ||||
|         rooms_changed = yield self.store.get_membership_changes_for_user( | ||||
|             user_id, since_token.room_key, now_token.room_key | ||||
|         ) | ||||
| 
 | ||||
|         if rooms_changed: | ||||
|             defer.returnValue(True) | ||||
| 
 | ||||
|         app_service = self.store.get_app_service_by_user_id(user_id) | ||||
|         if app_service: | ||||
|             rooms = yield self.store.get_app_service_rooms(app_service) | ||||
|             joined_room_ids = set(r.room_id for r in rooms) | ||||
|         else: | ||||
|             joined_room_ids = yield self.store.get_rooms_for_user(user_id) | ||||
| 
 | ||||
|         stream_id = RoomStreamToken.parse_stream_token(since_token.room_key).stream | ||||
|         for room_id in joined_room_ids: | ||||
|             if self.store.has_room_changed_since(room_id, stream_id): | ||||
|                 defer.returnValue(True) | ||||
|         defer.returnValue(False) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def _get_rooms_changed(self, sync_result_builder, ignored_users): | ||||
|         """Gets the the changes that have happened since the last sync. | ||||
| @ -841,8 +886,7 @@ class SyncHandler(object): | ||||
|             rooms = yield self.store.get_app_service_rooms(app_service) | ||||
|             joined_room_ids = set(r.room_id for r in rooms) | ||||
|         else: | ||||
|             rooms = yield self.store.get_rooms_for_user(user_id) | ||||
|             joined_room_ids = set(r.room_id for r in rooms) | ||||
|             joined_room_ids = yield self.store.get_rooms_for_user(user_id) | ||||
| 
 | ||||
|         # Get a list of membership change events that have happened. | ||||
|         rooms_changed = yield self.store.get_membership_changes_for_user( | ||||
|  | ||||
| @ -12,8 +12,7 @@ | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| 
 | ||||
| 
 | ||||
| import synapse.util.retryutils | ||||
| from twisted.internet import defer, reactor, protocol | ||||
| from twisted.internet.error import DNSLookupError | ||||
| from twisted.web.client import readBody, HTTPConnectionPool, Agent | ||||
| @ -22,7 +21,7 @@ from twisted.web._newclient import ResponseDone | ||||
| 
 | ||||
| from synapse.http.endpoint import matrix_federation_endpoint | ||||
| from synapse.util.async import sleep | ||||
| from synapse.util.logcontext import preserve_context_over_fn | ||||
| from synapse.util import logcontext | ||||
| import synapse.metrics | ||||
| 
 | ||||
| from canonicaljson import encode_canonical_json | ||||
| @ -94,6 +93,7 @@ class MatrixFederationHttpClient(object): | ||||
|             reactor, MatrixFederationEndpointFactory(hs), pool=pool | ||||
|         ) | ||||
|         self.clock = hs.get_clock() | ||||
|         self._store = hs.get_datastore() | ||||
|         self.version_string = hs.version_string | ||||
|         self._next_id = 1 | ||||
| 
 | ||||
| @ -103,123 +103,152 @@ class MatrixFederationHttpClient(object): | ||||
|         ) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def _create_request(self, destination, method, path_bytes, | ||||
|                         body_callback, headers_dict={}, param_bytes=b"", | ||||
|                         query_bytes=b"", retry_on_dns_fail=True, | ||||
|                         timeout=None, long_retries=False): | ||||
|         """ Creates and sends a request to the given url | ||||
|     def _request(self, destination, method, path, | ||||
|                  body_callback, headers_dict={}, param_bytes=b"", | ||||
|                  query_bytes=b"", retry_on_dns_fail=True, | ||||
|                  timeout=None, long_retries=False, | ||||
|                  ignore_backoff=False, | ||||
|                  backoff_on_404=False): | ||||
|         """ Creates and sends a request to the given server | ||||
|         Args: | ||||
|             destination (str): The remote server to send the HTTP request to. | ||||
|             method (str): HTTP method | ||||
|             path (str): The HTTP path | ||||
|             ignore_backoff (bool): true to ignore the historical backoff data | ||||
|                 and try the request anyway. | ||||
|             backoff_on_404 (bool): Back off if we get a 404 | ||||
| 
 | ||||
|         Returns: | ||||
|             Deferred: resolves with the http response object on success. | ||||
| 
 | ||||
|             Fails with ``HTTPRequestException``: if we get an HTTP response | ||||
|                 code >= 300. | ||||
|             Fails with ``NotRetryingDestination`` if we are not yet ready | ||||
|                 to retry this server. | ||||
|         """ | ||||
|         headers_dict[b"User-Agent"] = [self.version_string] | ||||
|         headers_dict[b"Host"] = [destination] | ||||
| 
 | ||||
|         url_bytes = self._create_url( | ||||
|             destination, path_bytes, param_bytes, query_bytes | ||||
|         limiter = yield synapse.util.retryutils.get_retry_limiter( | ||||
|             destination, | ||||
|             self.clock, | ||||
|             self._store, | ||||
|             backoff_on_404=backoff_on_404, | ||||
|             ignore_backoff=ignore_backoff, | ||||
|         ) | ||||
| 
 | ||||
|         txn_id = "%s-O-%s" % (method, self._next_id) | ||||
|         self._next_id = (self._next_id + 1) % (sys.maxint - 1) | ||||
|         destination = destination.encode("ascii") | ||||
|         path_bytes = path.encode("ascii") | ||||
|         with limiter: | ||||
|             headers_dict[b"User-Agent"] = [self.version_string] | ||||
|             headers_dict[b"Host"] = [destination] | ||||
| 
 | ||||
|         outbound_logger.info( | ||||
|             "{%s} [%s] Sending request: %s %s", | ||||
|             txn_id, destination, method, url_bytes | ||||
|         ) | ||||
|             url_bytes = self._create_url( | ||||
|                 destination, path_bytes, param_bytes, query_bytes | ||||
|             ) | ||||
| 
 | ||||
|         # XXX: Would be much nicer to retry only at the transaction-layer | ||||
|         # (once we have reliable transactions in place) | ||||
|         if long_retries: | ||||
|             retries_left = MAX_LONG_RETRIES | ||||
|         else: | ||||
|             retries_left = MAX_SHORT_RETRIES | ||||
|             txn_id = "%s-O-%s" % (method, self._next_id) | ||||
|             self._next_id = (self._next_id + 1) % (sys.maxint - 1) | ||||
| 
 | ||||
|         http_url_bytes = urlparse.urlunparse( | ||||
|             ("", "", path_bytes, param_bytes, query_bytes, "") | ||||
|         ) | ||||
|             outbound_logger.info( | ||||
|                 "{%s} [%s] Sending request: %s %s", | ||||
|                 txn_id, destination, method, url_bytes | ||||
|             ) | ||||
| 
 | ||||
|         log_result = None | ||||
|         try: | ||||
|             while True: | ||||
|                 producer = None | ||||
|                 if body_callback: | ||||
|                     producer = body_callback(method, http_url_bytes, headers_dict) | ||||
|             # XXX: Would be much nicer to retry only at the transaction-layer | ||||
|             # (once we have reliable transactions in place) | ||||
|             if long_retries: | ||||
|                 retries_left = MAX_LONG_RETRIES | ||||
|             else: | ||||
|                 retries_left = MAX_SHORT_RETRIES | ||||
| 
 | ||||
|                 try: | ||||
|                     def send_request(): | ||||
|                         request_deferred = preserve_context_over_fn( | ||||
|                             self.agent.request, | ||||
|             http_url_bytes = urlparse.urlunparse( | ||||
|                 ("", "", path_bytes, param_bytes, query_bytes, "") | ||||
|             ) | ||||
| 
 | ||||
|             log_result = None | ||||
|             try: | ||||
|                 while True: | ||||
|                     producer = None | ||||
|                     if body_callback: | ||||
|                         producer = body_callback(method, http_url_bytes, headers_dict) | ||||
| 
 | ||||
|                     try: | ||||
|                         def send_request(): | ||||
|                             request_deferred = self.agent.request( | ||||
|                                 method, | ||||
|                                 url_bytes, | ||||
|                                 Headers(headers_dict), | ||||
|                                 producer | ||||
|                             ) | ||||
| 
 | ||||
|                             return self.clock.time_bound_deferred( | ||||
|                                 request_deferred, | ||||
|                                 time_out=timeout / 1000. if timeout else 60, | ||||
|                             ) | ||||
| 
 | ||||
|                         with logcontext.PreserveLoggingContext(): | ||||
|                             response = yield send_request() | ||||
| 
 | ||||
|                         log_result = "%d %s" % (response.code, response.phrase,) | ||||
|                         break | ||||
|                     except Exception as e: | ||||
|                         if not retry_on_dns_fail and isinstance(e, DNSLookupError): | ||||
|                             logger.warn( | ||||
|                                 "DNS Lookup failed to %s with %s", | ||||
|                                 destination, | ||||
|                                 e | ||||
|                             ) | ||||
|                             log_result = "DNS Lookup failed to %s with %s" % ( | ||||
|                                 destination, e | ||||
|                             ) | ||||
|                             raise | ||||
| 
 | ||||
|                         logger.warn( | ||||
|                             "{%s} Sending request failed to %s: %s %s: %s - %s", | ||||
|                             txn_id, | ||||
|                             destination, | ||||
|                             method, | ||||
|                             url_bytes, | ||||
|                             Headers(headers_dict), | ||||
|                             producer | ||||
|                             type(e).__name__, | ||||
|                             _flatten_response_never_received(e), | ||||
|                         ) | ||||
| 
 | ||||
|                         return self.clock.time_bound_deferred( | ||||
|                             request_deferred, | ||||
|                             time_out=timeout / 1000. if timeout else 60, | ||||
|                         log_result = "%s - %s" % ( | ||||
|                             type(e).__name__, _flatten_response_never_received(e), | ||||
|                         ) | ||||
| 
 | ||||
|                     response = yield preserve_context_over_fn(send_request) | ||||
|                         if retries_left and not timeout: | ||||
|                             if long_retries: | ||||
|                                 delay = 4 ** (MAX_LONG_RETRIES + 1 - retries_left) | ||||
|                                 delay = min(delay, 60) | ||||
|                                 delay *= random.uniform(0.8, 1.4) | ||||
|                             else: | ||||
|                                 delay = 0.5 * 2 ** (MAX_SHORT_RETRIES - retries_left) | ||||
|                                 delay = min(delay, 2) | ||||
|                                 delay *= random.uniform(0.8, 1.4) | ||||
| 
 | ||||
|                     log_result = "%d %s" % (response.code, response.phrase,) | ||||
|                     break | ||||
|                 except Exception as e: | ||||
|                     if not retry_on_dns_fail and isinstance(e, DNSLookupError): | ||||
|                         logger.warn( | ||||
|                             "DNS Lookup failed to %s with %s", | ||||
|                             destination, | ||||
|                             e | ||||
|                         ) | ||||
|                         log_result = "DNS Lookup failed to %s with %s" % ( | ||||
|                             destination, e | ||||
|                         ) | ||||
|                         raise | ||||
| 
 | ||||
|                     logger.warn( | ||||
|                         "{%s} Sending request failed to %s: %s %s: %s - %s", | ||||
|                         txn_id, | ||||
|                         destination, | ||||
|                         method, | ||||
|                         url_bytes, | ||||
|                         type(e).__name__, | ||||
|                         _flatten_response_never_received(e), | ||||
|                     ) | ||||
| 
 | ||||
|                     log_result = "%s - %s" % ( | ||||
|                         type(e).__name__, _flatten_response_never_received(e), | ||||
|                     ) | ||||
| 
 | ||||
|                     if retries_left and not timeout: | ||||
|                         if long_retries: | ||||
|                             delay = 4 ** (MAX_LONG_RETRIES + 1 - retries_left) | ||||
|                             delay = min(delay, 60) | ||||
|                             delay *= random.uniform(0.8, 1.4) | ||||
|                             yield sleep(delay) | ||||
|                             retries_left -= 1 | ||||
|                         else: | ||||
|                             delay = 0.5 * 2 ** (MAX_SHORT_RETRIES - retries_left) | ||||
|                             delay = min(delay, 2) | ||||
|                             delay *= random.uniform(0.8, 1.4) | ||||
|                             raise | ||||
|             finally: | ||||
|                 outbound_logger.info( | ||||
|                     "{%s} [%s] Result: %s", | ||||
|                     txn_id, | ||||
|                     destination, | ||||
|                     log_result, | ||||
|                 ) | ||||
| 
 | ||||
|                         yield sleep(delay) | ||||
|                         retries_left -= 1 | ||||
|                     else: | ||||
|                         raise | ||||
|         finally: | ||||
|             outbound_logger.info( | ||||
|                 "{%s} [%s] Result: %s", | ||||
|                 txn_id, | ||||
|                 destination, | ||||
|                 log_result, | ||||
|             ) | ||||
|             if 200 <= response.code < 300: | ||||
|                 pass | ||||
|             else: | ||||
|                 # :'( | ||||
|                 # Update transactions table? | ||||
|                 with logcontext.PreserveLoggingContext(): | ||||
|                     body = yield readBody(response) | ||||
|                 raise HttpResponseException( | ||||
|                     response.code, response.phrase, body | ||||
|                 ) | ||||
| 
 | ||||
|         if 200 <= response.code < 300: | ||||
|             pass | ||||
|         else: | ||||
|             # :'( | ||||
|             # Update transactions table? | ||||
|             body = yield preserve_context_over_fn(readBody, response) | ||||
|             raise HttpResponseException( | ||||
|                 response.code, response.phrase, body | ||||
|             ) | ||||
| 
 | ||||
|         defer.returnValue(response) | ||||
|             defer.returnValue(response) | ||||
| 
 | ||||
|     def sign_request(self, destination, method, url_bytes, headers_dict, | ||||
|                      content=None): | ||||
| @ -248,7 +277,9 @@ class MatrixFederationHttpClient(object): | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def put_json(self, destination, path, data={}, json_data_callback=None, | ||||
|                  long_retries=False, timeout=None): | ||||
|                  long_retries=False, timeout=None, | ||||
|                  ignore_backoff=False, | ||||
|                  backoff_on_404=False): | ||||
|         """ Sends the specifed json data using PUT | ||||
| 
 | ||||
|         Args: | ||||
| @ -263,11 +294,19 @@ class MatrixFederationHttpClient(object): | ||||
|                 retry for a short or long time. | ||||
|             timeout(int): How long to try (in ms) the destination for before | ||||
|                 giving up. None indicates no timeout. | ||||
|             ignore_backoff (bool): true to ignore the historical backoff data | ||||
|                 and try the request anyway. | ||||
|             backoff_on_404 (bool): True if we should count a 404 response as | ||||
|                 a failure of the server (and should therefore back off future | ||||
|                 requests) | ||||
| 
 | ||||
|         Returns: | ||||
|             Deferred: Succeeds when we get a 2xx HTTP response. The result | ||||
|             will be the decoded JSON body. On a 4xx or 5xx error response a | ||||
|             CodeMessageException is raised. | ||||
| 
 | ||||
|             Fails with ``NotRetryingDestination`` if we are not yet ready | ||||
|             to retry this server. | ||||
|         """ | ||||
| 
 | ||||
|         if not json_data_callback: | ||||
| @ -282,26 +321,29 @@ class MatrixFederationHttpClient(object): | ||||
|             producer = _JsonProducer(json_data) | ||||
|             return producer | ||||
| 
 | ||||
|         response = yield self._create_request( | ||||
|             destination.encode("ascii"), | ||||
|         response = yield self._request( | ||||
|             destination, | ||||
|             "PUT", | ||||
|             path.encode("ascii"), | ||||
|             path, | ||||
|             body_callback=body_callback, | ||||
|             headers_dict={"Content-Type": ["application/json"]}, | ||||
|             long_retries=long_retries, | ||||
|             timeout=timeout, | ||||
|             ignore_backoff=ignore_backoff, | ||||
|             backoff_on_404=backoff_on_404, | ||||
|         ) | ||||
| 
 | ||||
|         if 200 <= response.code < 300: | ||||
|             # We need to update the transactions table to say it was sent? | ||||
|             check_content_type_is_json(response.headers) | ||||
| 
 | ||||
|         body = yield preserve_context_over_fn(readBody, response) | ||||
|         with logcontext.PreserveLoggingContext(): | ||||
|             body = yield readBody(response) | ||||
|         defer.returnValue(json.loads(body)) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def post_json(self, destination, path, data={}, long_retries=False, | ||||
|                   timeout=None): | ||||
|                   timeout=None, ignore_backoff=False): | ||||
|         """ Sends the specifed json data using POST | ||||
| 
 | ||||
|         Args: | ||||
| @ -314,11 +356,15 @@ class MatrixFederationHttpClient(object): | ||||
|                 retry for a short or long time. | ||||
|             timeout(int): How long to try (in ms) the destination for before | ||||
|                 giving up. None indicates no timeout. | ||||
| 
 | ||||
|             ignore_backoff (bool): true to ignore the historical backoff data and | ||||
|                 try the request anyway. | ||||
|         Returns: | ||||
|             Deferred: Succeeds when we get a 2xx HTTP response. The result | ||||
|             will be the decoded JSON body. On a 4xx or 5xx error response a | ||||
|             CodeMessageException is raised. | ||||
| 
 | ||||
|             Fails with ``NotRetryingDestination`` if we are not yet ready | ||||
|             to retry this server. | ||||
|         """ | ||||
| 
 | ||||
|         def body_callback(method, url_bytes, headers_dict): | ||||
| @ -327,27 +373,29 @@ class MatrixFederationHttpClient(object): | ||||
|             ) | ||||
|             return _JsonProducer(data) | ||||
| 
 | ||||
|         response = yield self._create_request( | ||||
|             destination.encode("ascii"), | ||||
|         response = yield self._request( | ||||
|             destination, | ||||
|             "POST", | ||||
|             path.encode("ascii"), | ||||
|             path, | ||||
|             body_callback=body_callback, | ||||
|             headers_dict={"Content-Type": ["application/json"]}, | ||||
|             long_retries=long_retries, | ||||
|             timeout=timeout, | ||||
|             ignore_backoff=ignore_backoff, | ||||
|         ) | ||||
| 
 | ||||
|         if 200 <= response.code < 300: | ||||
|             # We need to update the transactions table to say it was sent? | ||||
|             check_content_type_is_json(response.headers) | ||||
| 
 | ||||
|         body = yield preserve_context_over_fn(readBody, response) | ||||
|         with logcontext.PreserveLoggingContext(): | ||||
|             body = yield readBody(response) | ||||
| 
 | ||||
|         defer.returnValue(json.loads(body)) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def get_json(self, destination, path, args={}, retry_on_dns_fail=True, | ||||
|                  timeout=None): | ||||
|                  timeout=None, ignore_backoff=False): | ||||
|         """ GETs some json from the given host homeserver and path | ||||
| 
 | ||||
|         Args: | ||||
| @ -359,11 +407,16 @@ class MatrixFederationHttpClient(object): | ||||
|             timeout (int): How long to try (in ms) the destination for before | ||||
|                 giving up. None indicates no timeout and that the request will | ||||
|                 be retried. | ||||
|             ignore_backoff (bool): true to ignore the historical backoff data | ||||
|                 and try the request anyway. | ||||
|         Returns: | ||||
|             Deferred: Succeeds when we get *any* HTTP response. | ||||
| 
 | ||||
|             The result of the deferred is a tuple of `(code, response)`, | ||||
|             where `response` is a dict representing the decoded JSON body. | ||||
| 
 | ||||
|             Fails with ``NotRetryingDestination`` if we are not yet ready | ||||
|             to retry this server. | ||||
|         """ | ||||
|         logger.debug("get_json args: %s", args) | ||||
| 
 | ||||
| @ -380,36 +433,47 @@ class MatrixFederationHttpClient(object): | ||||
|             self.sign_request(destination, method, url_bytes, headers_dict) | ||||
|             return None | ||||
| 
 | ||||
|         response = yield self._create_request( | ||||
|             destination.encode("ascii"), | ||||
|         response = yield self._request( | ||||
|             destination, | ||||
|             "GET", | ||||
|             path.encode("ascii"), | ||||
|             path, | ||||
|             query_bytes=query_bytes, | ||||
|             body_callback=body_callback, | ||||
|             retry_on_dns_fail=retry_on_dns_fail, | ||||
|             timeout=timeout, | ||||
|             ignore_backoff=ignore_backoff, | ||||
|         ) | ||||
| 
 | ||||
|         if 200 <= response.code < 300: | ||||
|             # We need to update the transactions table to say it was sent? | ||||
|             check_content_type_is_json(response.headers) | ||||
| 
 | ||||
|         body = yield preserve_context_over_fn(readBody, response) | ||||
|         with logcontext.PreserveLoggingContext(): | ||||
|             body = yield readBody(response) | ||||
| 
 | ||||
|         defer.returnValue(json.loads(body)) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def get_file(self, destination, path, output_stream, args={}, | ||||
|                  retry_on_dns_fail=True, max_size=None): | ||||
|                  retry_on_dns_fail=True, max_size=None, | ||||
|                  ignore_backoff=False): | ||||
|         """GETs a file from a given homeserver | ||||
|         Args: | ||||
|             destination (str): The remote server to send the HTTP request to. | ||||
|             path (str): The HTTP path to GET. | ||||
|             output_stream (file): File to write the response body to. | ||||
|             args (dict): Optional dictionary used to create the query string. | ||||
|             ignore_backoff (bool): true to ignore the historical backoff data | ||||
|                 and try the request anyway. | ||||
|         Returns: | ||||
|             A (int,dict) tuple of the file length and a dict of the response | ||||
|             headers. | ||||
|             Deferred: resolves with an (int,dict) tuple of the file length and | ||||
|             a dict of the response headers. | ||||
| 
 | ||||
|             Fails with ``HTTPRequestException`` if we get an HTTP response code | ||||
|             >= 300 | ||||
| 
 | ||||
|             Fails with ``NotRetryingDestination`` if we are not yet ready | ||||
|             to retry this server. | ||||
|         """ | ||||
| 
 | ||||
|         encoded_args = {} | ||||
| @ -419,28 +483,29 @@ class MatrixFederationHttpClient(object): | ||||
|             encoded_args[k] = [v.encode("UTF-8") for v in vs] | ||||
| 
 | ||||
|         query_bytes = urllib.urlencode(encoded_args, True) | ||||
|         logger.debug("Query bytes: %s Retry DNS: %s", args, retry_on_dns_fail) | ||||
|         logger.debug("Query bytes: %s Retry DNS: %s", query_bytes, retry_on_dns_fail) | ||||
| 
 | ||||
|         def body_callback(method, url_bytes, headers_dict): | ||||
|             self.sign_request(destination, method, url_bytes, headers_dict) | ||||
|             return None | ||||
| 
 | ||||
|         response = yield self._create_request( | ||||
|             destination.encode("ascii"), | ||||
|         response = yield self._request( | ||||
|             destination, | ||||
|             "GET", | ||||
|             path.encode("ascii"), | ||||
|             path, | ||||
|             query_bytes=query_bytes, | ||||
|             body_callback=body_callback, | ||||
|             retry_on_dns_fail=retry_on_dns_fail | ||||
|             retry_on_dns_fail=retry_on_dns_fail, | ||||
|             ignore_backoff=ignore_backoff, | ||||
|         ) | ||||
| 
 | ||||
|         headers = dict(response.headers.getAllRawHeaders()) | ||||
| 
 | ||||
|         try: | ||||
|             length = yield preserve_context_over_fn( | ||||
|                 _readBodyToFile, | ||||
|                 response, output_stream, max_size | ||||
|             ) | ||||
|             with logcontext.PreserveLoggingContext(): | ||||
|                 length = yield _readBodyToFile( | ||||
|                     response, output_stream, max_size | ||||
|                 ) | ||||
|         except: | ||||
|             logger.exception("Failed to download body") | ||||
|             raise | ||||
|  | ||||
| @ -192,6 +192,16 @@ def parse_json_object_from_request(request): | ||||
|     return content | ||||
| 
 | ||||
| 
 | ||||
| def assert_params_in_request(body, required): | ||||
|     absent = [] | ||||
|     for k in required: | ||||
|         if k not in body: | ||||
|             absent.append(k) | ||||
| 
 | ||||
|     if len(absent) > 0: | ||||
|         raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM) | ||||
| 
 | ||||
| 
 | ||||
| class RestServlet(object): | ||||
| 
 | ||||
|     """ A Synapse REST Servlet. | ||||
|  | ||||
| @ -16,6 +16,7 @@ | ||||
| from twisted.internet import defer | ||||
| from synapse.api.constants import EventTypes, Membership | ||||
| from synapse.api.errors import AuthError | ||||
| from synapse.handlers.presence import format_user_presence_state | ||||
| 
 | ||||
| from synapse.util import DeferredTimedOutError | ||||
| from synapse.util.logutils import log_function | ||||
| @ -37,6 +38,10 @@ metrics = synapse.metrics.get_metrics_for(__name__) | ||||
| 
 | ||||
| notified_events_counter = metrics.register_counter("notified_events") | ||||
| 
 | ||||
| users_woken_by_stream_counter = metrics.register_counter( | ||||
|     "users_woken_by_stream", labels=["stream"] | ||||
| ) | ||||
| 
 | ||||
| 
 | ||||
| # TODO(paul): Should be shared somewhere | ||||
| def count(func, l): | ||||
| @ -73,6 +78,13 @@ class _NotifierUserStream(object): | ||||
|         self.user_id = user_id | ||||
|         self.rooms = set(rooms) | ||||
|         self.current_token = current_token | ||||
| 
 | ||||
|         # The last token for which we should wake up any streams that have a | ||||
|         # token that comes before it. This gets updated everytime we get poked. | ||||
|         # We start it at the current token since if we get any streams | ||||
|         # that have a token from before we have no idea whether they should be | ||||
|         # woken up or not, so lets just wake them up. | ||||
|         self.last_notified_token = current_token | ||||
|         self.last_notified_ms = time_now_ms | ||||
| 
 | ||||
|         with PreserveLoggingContext(): | ||||
| @ -89,9 +101,12 @@ class _NotifierUserStream(object): | ||||
|         self.current_token = self.current_token.copy_and_advance( | ||||
|             stream_key, stream_id | ||||
|         ) | ||||
|         self.last_notified_token = self.current_token | ||||
|         self.last_notified_ms = time_now_ms | ||||
|         noify_deferred = self.notify_deferred | ||||
| 
 | ||||
|         users_woken_by_stream_counter.inc(stream_key) | ||||
| 
 | ||||
|         with PreserveLoggingContext(): | ||||
|             self.notify_deferred = ObservableDeferred(defer.Deferred()) | ||||
|             noify_deferred.callback(self.current_token) | ||||
| @ -113,8 +128,14 @@ class _NotifierUserStream(object): | ||||
|     def new_listener(self, token): | ||||
|         """Returns a deferred that is resolved when there is a new token | ||||
|         greater than the given token. | ||||
| 
 | ||||
|         Args: | ||||
|             token: The token from which we are streaming from, i.e. we shouldn't | ||||
|                 notify for things that happened before this. | ||||
|         """ | ||||
|         if self.current_token.is_after(token): | ||||
|         # Immediately wake up stream if something has already since happened | ||||
|         # since their last token. | ||||
|         if self.last_notified_token.is_after(token): | ||||
|             return _NotificationListener(defer.succeed(self.current_token)) | ||||
|         else: | ||||
|             return _NotificationListener(self.notify_deferred.observe()) | ||||
| @ -283,8 +304,7 @@ class Notifier(object): | ||||
|         if user_stream is None: | ||||
|             current_token = yield self.event_sources.get_current_token() | ||||
|             if room_ids is None: | ||||
|                 rooms = yield self.store.get_rooms_for_user(user_id) | ||||
|                 room_ids = [room.room_id for room in rooms] | ||||
|                 room_ids = yield self.store.get_rooms_for_user(user_id) | ||||
|             user_stream = _NotifierUserStream( | ||||
|                 user_id=user_id, | ||||
|                 rooms=room_ids, | ||||
| @ -294,40 +314,44 @@ class Notifier(object): | ||||
|             self._register_with_keys(user_stream) | ||||
| 
 | ||||
|         result = None | ||||
|         prev_token = from_token | ||||
|         if timeout: | ||||
|             end_time = self.clock.time_msec() + timeout | ||||
| 
 | ||||
|             prev_token = from_token | ||||
|             while not result: | ||||
|                 try: | ||||
|                     current_token = user_stream.current_token | ||||
| 
 | ||||
|                     result = yield callback(prev_token, current_token) | ||||
|                     if result: | ||||
|                         break | ||||
| 
 | ||||
|                     now = self.clock.time_msec() | ||||
|                     if end_time <= now: | ||||
|                         break | ||||
| 
 | ||||
|                     # Now we wait for the _NotifierUserStream to be told there | ||||
|                     # is a new token. | ||||
|                     # We need to supply the token we supplied to callback so | ||||
|                     # that we don't miss any current_token updates. | ||||
|                     prev_token = current_token | ||||
|                     listener = user_stream.new_listener(prev_token) | ||||
|                     with PreserveLoggingContext(): | ||||
|                         yield self.clock.time_bound_deferred( | ||||
|                             listener.deferred, | ||||
|                             time_out=(end_time - now) / 1000. | ||||
|                         ) | ||||
| 
 | ||||
|                     current_token = user_stream.current_token | ||||
| 
 | ||||
|                     result = yield callback(prev_token, current_token) | ||||
|                     if result: | ||||
|                         break | ||||
| 
 | ||||
|                     # Update the prev_token to the current_token since nothing | ||||
|                     # has happened between the old prev_token and the current_token | ||||
|                     prev_token = current_token | ||||
|                 except DeferredTimedOutError: | ||||
|                     break | ||||
|                 except defer.CancelledError: | ||||
|                     break | ||||
|         else: | ||||
| 
 | ||||
|         if result is None: | ||||
|             # This happened if there was no timeout or if the timeout had | ||||
|             # already expired. | ||||
|             current_token = user_stream.current_token | ||||
|             result = yield callback(from_token, current_token) | ||||
|             result = yield callback(prev_token, current_token) | ||||
| 
 | ||||
|         defer.returnValue(result) | ||||
| 
 | ||||
| @ -388,6 +412,15 @@ class Notifier(object): | ||||
|                         new_events, | ||||
|                         is_peeking=is_peeking, | ||||
|                     ) | ||||
|                 elif name == "presence": | ||||
|                     now = self.clock.time_msec() | ||||
|                     new_events[:] = [ | ||||
|                         { | ||||
|                             "type": "m.presence", | ||||
|                             "content": format_user_presence_state(event, now), | ||||
|                         } | ||||
|                         for event in new_events | ||||
|                     ] | ||||
| 
 | ||||
|                 events.extend(new_events) | ||||
|                 end_token = end_token.copy_and_replace(keyname, new_key) | ||||
| @ -420,8 +453,7 @@ class Notifier(object): | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def _get_room_ids(self, user, explicit_room_id): | ||||
|         joined_rooms = yield self.store.get_rooms_for_user(user.to_string()) | ||||
|         joined_room_ids = map(lambda r: r.room_id, joined_rooms) | ||||
|         joined_room_ids = yield self.store.get_rooms_for_user(user.to_string()) | ||||
|         if explicit_room_id: | ||||
|             if explicit_room_id in joined_room_ids: | ||||
|                 defer.returnValue(([explicit_room_id], True)) | ||||
|  | ||||
| @ -139,7 +139,7 @@ class Mailer(object): | ||||
| 
 | ||||
|         @defer.inlineCallbacks | ||||
|         def _fetch_room_state(room_id): | ||||
|             room_state = yield self.state_handler.get_current_state_ids(room_id) | ||||
|             room_state = yield self.store.get_current_state_ids(room_id) | ||||
|             state_by_room[room_id] = room_state | ||||
| 
 | ||||
|         # Run at most 3 of these at once: sync does 10 at a time but email | ||||
|  | ||||
| @ -17,6 +17,7 @@ import logging | ||||
| import re | ||||
| 
 | ||||
| from synapse.types import UserID | ||||
| from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache | ||||
| from synapse.util.caches.lrucache import LruCache | ||||
| 
 | ||||
| logger = logging.getLogger(__name__) | ||||
| @ -125,6 +126,11 @@ class PushRuleEvaluatorForEvent(object): | ||||
|         return self._value_cache.get(dotted_key, None) | ||||
| 
 | ||||
| 
 | ||||
| # Caches (glob, word_boundary) -> regex for push. See _glob_matches | ||||
| regex_cache = LruCache(50000 * CACHE_SIZE_FACTOR) | ||||
| register_cache("regex_push_cache", regex_cache) | ||||
| 
 | ||||
| 
 | ||||
| def _glob_matches(glob, value, word_boundary=False): | ||||
|     """Tests if value matches glob. | ||||
| 
 | ||||
| @ -137,46 +143,63 @@ def _glob_matches(glob, value, word_boundary=False): | ||||
|     Returns: | ||||
|         bool | ||||
|     """ | ||||
| 
 | ||||
|     try: | ||||
|         if IS_GLOB.search(glob): | ||||
|             r = re.escape(glob) | ||||
| 
 | ||||
|             r = r.replace(r'\*', '.*?') | ||||
|             r = r.replace(r'\?', '.') | ||||
| 
 | ||||
|             # handle [abc], [a-z] and [!a-z] style ranges. | ||||
|             r = GLOB_REGEX.sub( | ||||
|                 lambda x: ( | ||||
|                     '[%s%s]' % ( | ||||
|                         x.group(1) and '^' or '', | ||||
|                         x.group(2).replace(r'\\\-', '-') | ||||
|                     ) | ||||
|                 ), | ||||
|                 r, | ||||
|             ) | ||||
|             if word_boundary: | ||||
|                 r = r"\b%s\b" % (r,) | ||||
|                 r = _compile_regex(r) | ||||
| 
 | ||||
|                 return r.search(value) | ||||
|             else: | ||||
|                 r = r + "$" | ||||
|                 r = _compile_regex(r) | ||||
| 
 | ||||
|                 return r.match(value) | ||||
|         elif word_boundary: | ||||
|             r = re.escape(glob) | ||||
|             r = r"\b%s\b" % (r,) | ||||
|             r = _compile_regex(r) | ||||
| 
 | ||||
|             return r.search(value) | ||||
|         else: | ||||
|             return value.lower() == glob.lower() | ||||
|         r = regex_cache.get((glob, word_boundary), None) | ||||
|         if not r: | ||||
|             r = _glob_to_re(glob, word_boundary) | ||||
|             regex_cache[(glob, word_boundary)] = r | ||||
|         return r.search(value) | ||||
|     except re.error: | ||||
|         logger.warn("Failed to parse glob to regex: %r", glob) | ||||
|         return False | ||||
| 
 | ||||
| 
 | ||||
| def _glob_to_re(glob, word_boundary): | ||||
|     """Generates regex for a given glob. | ||||
| 
 | ||||
|     Args: | ||||
|         glob (string) | ||||
|         word_boundary (bool): Whether to match against word boundaries or entire | ||||
|             string. Defaults to False. | ||||
| 
 | ||||
|     Returns: | ||||
|         regex object | ||||
|     """ | ||||
|     if IS_GLOB.search(glob): | ||||
|         r = re.escape(glob) | ||||
| 
 | ||||
|         r = r.replace(r'\*', '.*?') | ||||
|         r = r.replace(r'\?', '.') | ||||
| 
 | ||||
|         # handle [abc], [a-z] and [!a-z] style ranges. | ||||
|         r = GLOB_REGEX.sub( | ||||
|             lambda x: ( | ||||
|                 '[%s%s]' % ( | ||||
|                     x.group(1) and '^' or '', | ||||
|                     x.group(2).replace(r'\\\-', '-') | ||||
|                 ) | ||||
|             ), | ||||
|             r, | ||||
|         ) | ||||
|         if word_boundary: | ||||
|             r = r"\b%s\b" % (r,) | ||||
| 
 | ||||
|             return re.compile(r, flags=re.IGNORECASE) | ||||
|         else: | ||||
|             r = "^" + r + "$" | ||||
| 
 | ||||
|             return re.compile(r, flags=re.IGNORECASE) | ||||
|     elif word_boundary: | ||||
|         r = re.escape(glob) | ||||
|         r = r"\b%s\b" % (r,) | ||||
| 
 | ||||
|         return re.compile(r, flags=re.IGNORECASE) | ||||
|     else: | ||||
|         r = "^" + re.escape(glob) + "$" | ||||
|         return re.compile(r, flags=re.IGNORECASE) | ||||
| 
 | ||||
| 
 | ||||
| def _flatten_dict(d, prefix=[], result={}): | ||||
|     for key, value in d.items(): | ||||
|         if isinstance(value, basestring): | ||||
| @ -185,16 +208,3 @@ def _flatten_dict(d, prefix=[], result={}): | ||||
|             _flatten_dict(value, prefix=(prefix + [key]), result=result) | ||||
| 
 | ||||
|     return result | ||||
| 
 | ||||
| 
 | ||||
| regex_cache = LruCache(5000) | ||||
| 
 | ||||
| 
 | ||||
| def _compile_regex(regex_str): | ||||
|     r = regex_cache.get(regex_str, None) | ||||
|     if r: | ||||
|         return r | ||||
| 
 | ||||
|     r = re.compile(regex_str, flags=re.IGNORECASE) | ||||
|     regex_cache[regex_str] = r | ||||
|     return r | ||||
|  | ||||
| @ -33,13 +33,13 @@ def get_badge_count(store, user_id): | ||||
| 
 | ||||
|     badge = len(invites) | ||||
| 
 | ||||
|     for r in joins: | ||||
|         if r.room_id in my_receipts_by_room: | ||||
|             last_unread_event_id = my_receipts_by_room[r.room_id] | ||||
|     for room_id in joins: | ||||
|         if room_id in my_receipts_by_room: | ||||
|             last_unread_event_id = my_receipts_by_room[room_id] | ||||
| 
 | ||||
|             notifs = yield ( | ||||
|                 store.get_unread_event_push_actions_by_room_for_user( | ||||
|                     r.room_id, user_id, last_unread_event_id | ||||
|                     room_id, user_id, last_unread_event_id | ||||
|                 ) | ||||
|             ) | ||||
|             # return one badge count per conversation, as count per | ||||
|  | ||||
| @ -1,4 +1,5 @@ | ||||
| # Copyright 2015, 2016 OpenMarket Ltd | ||||
| # Copyright 2017 Vector Creations Ltd | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| @ -18,6 +19,7 @@ from distutils.version import LooseVersion | ||||
| logger = logging.getLogger(__name__) | ||||
| 
 | ||||
| REQUIREMENTS = { | ||||
|     "jsonschema>=2.5.1": ["jsonschema>=2.5.1"], | ||||
|     "frozendict>=0.4": ["frozendict"], | ||||
|     "unpaddedbase64>=1.1.0": ["unpaddedbase64>=1.1.0"], | ||||
|     "canonicaljson>=1.0.0": ["canonicaljson>=1.0.0"], | ||||
| @ -37,6 +39,7 @@ REQUIREMENTS = { | ||||
|     "pysaml2>=3.0.0,<4.0.0": ["saml2>=3.0.0,<4.0.0"], | ||||
|     "pymacaroons-pynacl": ["pymacaroons"], | ||||
|     "msgpack-python>=0.3.0": ["msgpack"], | ||||
|     "phonenumbers>=8.2.0": ["phonenumbers"], | ||||
| } | ||||
| CONDITIONAL_REQUIREMENTS = { | ||||
|     "web_client": { | ||||
|  | ||||
| @ -283,12 +283,12 @@ class ReplicationResource(Resource): | ||||
| 
 | ||||
|             if request_events != upto_events_token: | ||||
|                 writer.write_header_and_rows("events", res.new_forward_events, ( | ||||
|                     "position", "internal", "json", "state_group" | ||||
|                     "position", "event_id", "room_id", "type", "state_key", | ||||
|                 ), position=upto_events_token) | ||||
| 
 | ||||
|             if request_backfill != upto_backfill_token: | ||||
|                 writer.write_header_and_rows("backfill", res.new_backfill_events, ( | ||||
|                     "position", "internal", "json", "state_group", | ||||
|                     "position", "event_id", "room_id", "type", "state_key", "redacts", | ||||
|                 ), position=upto_backfill_token) | ||||
| 
 | ||||
|             writer.write_header_and_rows( | ||||
|  | ||||
| @ -27,4 +27,9 @@ class SlavedIdTracker(object): | ||||
|         self._current = (max if self.step > 0 else min)(self._current, new_id) | ||||
| 
 | ||||
|     def get_current_token(self): | ||||
|         """ | ||||
| 
 | ||||
|         Returns: | ||||
|             int | ||||
|         """ | ||||
|         return self._current | ||||
|  | ||||
| @ -16,7 +16,6 @@ from ._base import BaseSlavedStore | ||||
| from ._slaved_id_tracker import SlavedIdTracker | ||||
| 
 | ||||
| from synapse.api.constants import EventTypes | ||||
| from synapse.events import FrozenEvent | ||||
| from synapse.storage import DataStore | ||||
| from synapse.storage.roommember import RoomMemberStore | ||||
| from synapse.storage.event_federation import EventFederationStore | ||||
| @ -25,7 +24,6 @@ from synapse.storage.state import StateStore | ||||
| from synapse.storage.stream import StreamStore | ||||
| from synapse.util.caches.stream_change_cache import StreamChangeCache | ||||
| 
 | ||||
| import ujson as json | ||||
| import logging | ||||
| 
 | ||||
| 
 | ||||
| @ -109,6 +107,10 @@ class SlavedEventStore(BaseSlavedStore): | ||||
|     get_recent_event_ids_for_room = ( | ||||
|         StreamStore.__dict__["get_recent_event_ids_for_room"] | ||||
|     ) | ||||
|     get_current_state_ids = ( | ||||
|         StateStore.__dict__["get_current_state_ids"] | ||||
|     ) | ||||
|     has_room_changed_since = DataStore.has_room_changed_since.__func__ | ||||
| 
 | ||||
|     get_unread_push_actions_for_user_in_range_for_http = ( | ||||
|         DataStore.get_unread_push_actions_for_user_in_range_for_http.__func__ | ||||
| @ -165,7 +167,6 @@ class SlavedEventStore(BaseSlavedStore): | ||||
|     _get_rooms_for_user_where_membership_is_txn = ( | ||||
|         DataStore._get_rooms_for_user_where_membership_is_txn.__func__ | ||||
|     ) | ||||
|     _get_members_rows_txn = DataStore._get_members_rows_txn.__func__ | ||||
|     _get_state_for_groups = DataStore._get_state_for_groups.__func__ | ||||
|     _get_all_state_from_cache = DataStore._get_all_state_from_cache.__func__ | ||||
|     _get_events_around_txn = DataStore._get_events_around_txn.__func__ | ||||
| @ -238,46 +239,32 @@ class SlavedEventStore(BaseSlavedStore): | ||||
|         return super(SlavedEventStore, self).process_replication(result) | ||||
| 
 | ||||
|     def _process_replication_row(self, row, backfilled): | ||||
|         internal = json.loads(row[1]) | ||||
|         event_json = json.loads(row[2]) | ||||
|         event = FrozenEvent(event_json, internal_metadata_dict=internal) | ||||
|         stream_ordering = row[0] if not backfilled else -row[0] | ||||
|         self.invalidate_caches_for_event( | ||||
|             event, backfilled, | ||||
|             stream_ordering, row[1], row[2], row[3], row[4], row[5], | ||||
|             backfilled=backfilled, | ||||
|         ) | ||||
| 
 | ||||
|     def invalidate_caches_for_event(self, event, backfilled): | ||||
|         self._invalidate_get_event_cache(event.event_id) | ||||
|     def invalidate_caches_for_event(self, stream_ordering, event_id, room_id, | ||||
|                                     etype, state_key, redacts, backfilled): | ||||
|         self._invalidate_get_event_cache(event_id) | ||||
| 
 | ||||
|         self.get_latest_event_ids_in_room.invalidate((event.room_id,)) | ||||
|         self.get_latest_event_ids_in_room.invalidate((room_id,)) | ||||
| 
 | ||||
|         self.get_unread_event_push_actions_by_room_for_user.invalidate_many( | ||||
|             (event.room_id,) | ||||
|             (room_id,) | ||||
|         ) | ||||
| 
 | ||||
|         if not backfilled: | ||||
|             self._events_stream_cache.entity_has_changed( | ||||
|                 event.room_id, event.internal_metadata.stream_ordering | ||||
|                 room_id, stream_ordering | ||||
|             ) | ||||
| 
 | ||||
|         # self.get_unread_event_push_actions_by_room_for_user.invalidate_many( | ||||
|         #     (event.room_id,) | ||||
|         # ) | ||||
|         if redacts: | ||||
|             self._invalidate_get_event_cache(redacts) | ||||
| 
 | ||||
|         if event.type == EventTypes.Redaction: | ||||
|             self._invalidate_get_event_cache(event.redacts) | ||||
| 
 | ||||
|         if event.type == EventTypes.Member: | ||||
|         if etype == EventTypes.Member: | ||||
|             self._membership_stream_cache.entity_has_changed( | ||||
|                 event.state_key, event.internal_metadata.stream_ordering | ||||
|                 state_key, stream_ordering | ||||
|             ) | ||||
|             self.get_invited_rooms_for_user.invalidate((event.state_key,)) | ||||
| 
 | ||||
|         if not event.is_state(): | ||||
|             return | ||||
| 
 | ||||
|         if backfilled: | ||||
|             return | ||||
| 
 | ||||
|         if (not event.internal_metadata.is_invite_from_remote() | ||||
|                 and event.internal_metadata.is_outlier()): | ||||
|             return | ||||
|             self.get_invited_rooms_for_user.invalidate((state_key,)) | ||||
|  | ||||
| @ -57,5 +57,6 @@ class SlavedPresenceStore(BaseSlavedStore): | ||||
|                 self.presence_stream_cache.entity_has_changed( | ||||
|                     user_id, position | ||||
|                 ) | ||||
|                 self._get_presence_for_user.invalidate((user_id,)) | ||||
| 
 | ||||
|         return super(SlavedPresenceStore, self).process_replication(result) | ||||
|  | ||||
| @ -19,6 +19,7 @@ from synapse.api.errors import SynapseError, LoginError, Codes | ||||
| from synapse.types import UserID | ||||
| from synapse.http.server import finish_request | ||||
| from synapse.http.servlet import parse_json_object_from_request | ||||
| from synapse.util.msisdn import phone_number_to_msisdn | ||||
| 
 | ||||
| from .base import ClientV1RestServlet, client_path_patterns | ||||
| 
 | ||||
| @ -33,10 +34,55 @@ from saml2.client import Saml2Client | ||||
| 
 | ||||
| import xml.etree.ElementTree as ET | ||||
| 
 | ||||
| from twisted.web.client import PartialDownloadError | ||||
| 
 | ||||
| 
 | ||||
| logger = logging.getLogger(__name__) | ||||
| 
 | ||||
| 
 | ||||
| def login_submission_legacy_convert(submission): | ||||
|     """ | ||||
|     If the input login submission is an old style object | ||||
|     (ie. with top-level user / medium / address) convert it | ||||
|     to a typed object. | ||||
|     """ | ||||
|     if "user" in submission: | ||||
|         submission["identifier"] = { | ||||
|             "type": "m.id.user", | ||||
|             "user": submission["user"], | ||||
|         } | ||||
|         del submission["user"] | ||||
| 
 | ||||
|     if "medium" in submission and "address" in submission: | ||||
|         submission["identifier"] = { | ||||
|             "type": "m.id.thirdparty", | ||||
|             "medium": submission["medium"], | ||||
|             "address": submission["address"], | ||||
|         } | ||||
|         del submission["medium"] | ||||
|         del submission["address"] | ||||
| 
 | ||||
| 
 | ||||
| def login_id_thirdparty_from_phone(identifier): | ||||
|     """ | ||||
|     Convert a phone login identifier type to a generic threepid identifier | ||||
|     Args: | ||||
|         identifier(dict): Login identifier dict of type 'm.id.phone' | ||||
| 
 | ||||
|     Returns: Login identifier dict of type 'm.id.threepid' | ||||
|     """ | ||||
|     if "country" not in identifier or "number" not in identifier: | ||||
|         raise SynapseError(400, "Invalid phone-type identifier") | ||||
| 
 | ||||
|     msisdn = phone_number_to_msisdn(identifier["country"], identifier["number"]) | ||||
| 
 | ||||
|     return { | ||||
|         "type": "m.id.thirdparty", | ||||
|         "medium": "msisdn", | ||||
|         "address": msisdn, | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| class LoginRestServlet(ClientV1RestServlet): | ||||
|     PATTERNS = client_path_patterns("/login$") | ||||
|     PASS_TYPE = "m.login.password" | ||||
| @ -117,20 +163,52 @@ class LoginRestServlet(ClientV1RestServlet): | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def do_password_login(self, login_submission): | ||||
|         if 'medium' in login_submission and 'address' in login_submission: | ||||
|             address = login_submission['address'] | ||||
|             if login_submission['medium'] == 'email': | ||||
|         if "password" not in login_submission: | ||||
|             raise SynapseError(400, "Missing parameter: password") | ||||
| 
 | ||||
|         login_submission_legacy_convert(login_submission) | ||||
| 
 | ||||
|         if "identifier" not in login_submission: | ||||
|             raise SynapseError(400, "Missing param: identifier") | ||||
| 
 | ||||
|         identifier = login_submission["identifier"] | ||||
|         if "type" not in identifier: | ||||
|             raise SynapseError(400, "Login identifier has no type") | ||||
| 
 | ||||
|         # convert phone type identifiers to generic threepids | ||||
|         if identifier["type"] == "m.id.phone": | ||||
|             identifier = login_id_thirdparty_from_phone(identifier) | ||||
| 
 | ||||
|         # convert threepid identifiers to user IDs | ||||
|         if identifier["type"] == "m.id.thirdparty": | ||||
|             if 'medium' not in identifier or 'address' not in identifier: | ||||
|                 raise SynapseError(400, "Invalid thirdparty identifier") | ||||
| 
 | ||||
|             address = identifier['address'] | ||||
|             if identifier['medium'] == 'email': | ||||
|                 # For emails, transform the address to lowercase. | ||||
|                 # We store all email addreses as lowercase in the DB. | ||||
|                 # (See add_threepid in synapse/handlers/auth.py) | ||||
|                 address = address.lower() | ||||
|             user_id = yield self.hs.get_datastore().get_user_id_by_threepid( | ||||
|                 login_submission['medium'], address | ||||
|                 identifier['medium'], address | ||||
|             ) | ||||
|             if not user_id: | ||||
|                 raise LoginError(403, "", errcode=Codes.FORBIDDEN) | ||||
|         else: | ||||
|             user_id = login_submission['user'] | ||||
| 
 | ||||
|             identifier = { | ||||
|                 "type": "m.id.user", | ||||
|                 "user": user_id, | ||||
|             } | ||||
| 
 | ||||
|         # by this point, the identifier should be an m.id.user: if it's anything | ||||
|         # else, we haven't understood it. | ||||
|         if identifier["type"] != "m.id.user": | ||||
|             raise SynapseError(400, "Unknown login identifier type") | ||||
|         if "user" not in identifier: | ||||
|             raise SynapseError(400, "User identifier is missing 'user' key") | ||||
| 
 | ||||
|         user_id = identifier["user"] | ||||
| 
 | ||||
|         if not user_id.startswith('@'): | ||||
|             user_id = UserID.create( | ||||
| @ -341,7 +419,12 @@ class CasTicketServlet(ClientV1RestServlet): | ||||
|             "ticket": request.args["ticket"], | ||||
|             "service": self.cas_service_url | ||||
|         } | ||||
|         body = yield http_client.get_raw(uri, args) | ||||
|         try: | ||||
|             body = yield http_client.get_raw(uri, args) | ||||
|         except PartialDownloadError as pde: | ||||
|             # Twisted raises this error if the connection is closed, | ||||
|             # even if that's being used old-http style to signal end-of-data | ||||
|             body = pde.response | ||||
|         result = yield self.handle_cas_response(request, body, client_redirect_url) | ||||
|         defer.returnValue(result) | ||||
| 
 | ||||
|  | ||||
| @ -19,6 +19,7 @@ from twisted.internet import defer | ||||
| 
 | ||||
| from synapse.api.errors import SynapseError, AuthError | ||||
| from synapse.types import UserID | ||||
| from synapse.handlers.presence import format_user_presence_state | ||||
| from synapse.http.servlet import parse_json_object_from_request | ||||
| from .base import ClientV1RestServlet, client_path_patterns | ||||
| 
 | ||||
| @ -33,6 +34,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet): | ||||
|     def __init__(self, hs): | ||||
|         super(PresenceStatusRestServlet, self).__init__(hs) | ||||
|         self.presence_handler = hs.get_presence_handler() | ||||
|         self.clock = hs.get_clock() | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def on_GET(self, request, user_id): | ||||
| @ -48,6 +50,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet): | ||||
|                 raise AuthError(403, "You are not allowed to see their presence.") | ||||
| 
 | ||||
|         state = yield self.presence_handler.get_state(target_user=user) | ||||
|         state = format_user_presence_state(state, self.clock.time_msec()) | ||||
| 
 | ||||
|         defer.returnValue((200, state)) | ||||
| 
 | ||||
|  | ||||
| @ -748,8 +748,7 @@ class JoinedRoomsRestServlet(ClientV1RestServlet): | ||||
|     def on_GET(self, request): | ||||
|         requester = yield self.auth.get_user_by_req(request, allow_guest=True) | ||||
| 
 | ||||
|         rooms = yield self.store.get_rooms_for_user(requester.user.to_string()) | ||||
|         room_ids = set(r.room_id for r in rooms)  # Ensure they're unique. | ||||
|         room_ids = yield self.store.get_rooms_for_user(requester.user.to_string()) | ||||
|         defer.returnValue((200, {"joined_rooms": list(room_ids)})) | ||||
| 
 | ||||
| 
 | ||||
|  | ||||
| @ -1,5 +1,6 @@ | ||||
| # -*- coding: utf-8 -*- | ||||
| # Copyright 2015, 2016 OpenMarket Ltd | ||||
| # Copyright 2017 Vector Creations Ltd | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| @ -17,8 +18,11 @@ from twisted.internet import defer | ||||
| 
 | ||||
| from synapse.api.constants import LoginType | ||||
| from synapse.api.errors import LoginError, SynapseError, Codes | ||||
| from synapse.http.servlet import RestServlet, parse_json_object_from_request | ||||
| from synapse.http.servlet import ( | ||||
|     RestServlet, parse_json_object_from_request, assert_params_in_request | ||||
| ) | ||||
| from synapse.util.async import run_on_reactor | ||||
| from synapse.util.msisdn import phone_number_to_msisdn | ||||
| 
 | ||||
| from ._base import client_v2_patterns | ||||
| 
 | ||||
| @ -28,11 +32,11 @@ import logging | ||||
| logger = logging.getLogger(__name__) | ||||
| 
 | ||||
| 
 | ||||
| class PasswordRequestTokenRestServlet(RestServlet): | ||||
| class EmailPasswordRequestTokenRestServlet(RestServlet): | ||||
|     PATTERNS = client_v2_patterns("/account/password/email/requestToken$") | ||||
| 
 | ||||
|     def __init__(self, hs): | ||||
|         super(PasswordRequestTokenRestServlet, self).__init__() | ||||
|         super(EmailPasswordRequestTokenRestServlet, self).__init__() | ||||
|         self.hs = hs | ||||
|         self.identity_handler = hs.get_handlers().identity_handler | ||||
| 
 | ||||
| @ -40,14 +44,9 @@ class PasswordRequestTokenRestServlet(RestServlet): | ||||
|     def on_POST(self, request): | ||||
|         body = parse_json_object_from_request(request) | ||||
| 
 | ||||
|         required = ['id_server', 'client_secret', 'email', 'send_attempt'] | ||||
|         absent = [] | ||||
|         for k in required: | ||||
|             if k not in body: | ||||
|                 absent.append(k) | ||||
| 
 | ||||
|         if absent: | ||||
|             raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM) | ||||
|         assert_params_in_request(body, [ | ||||
|             'id_server', 'client_secret', 'email', 'send_attempt' | ||||
|         ]) | ||||
| 
 | ||||
|         existingUid = yield self.hs.get_datastore().get_user_id_by_threepid( | ||||
|             'email', body['email'] | ||||
| @ -60,6 +59,37 @@ class PasswordRequestTokenRestServlet(RestServlet): | ||||
|         defer.returnValue((200, ret)) | ||||
| 
 | ||||
| 
 | ||||
| class MsisdnPasswordRequestTokenRestServlet(RestServlet): | ||||
|     PATTERNS = client_v2_patterns("/account/password/msisdn/requestToken$") | ||||
| 
 | ||||
|     def __init__(self, hs): | ||||
|         super(MsisdnPasswordRequestTokenRestServlet, self).__init__() | ||||
|         self.hs = hs | ||||
|         self.datastore = self.hs.get_datastore() | ||||
|         self.identity_handler = hs.get_handlers().identity_handler | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def on_POST(self, request): | ||||
|         body = parse_json_object_from_request(request) | ||||
| 
 | ||||
|         assert_params_in_request(body, [ | ||||
|             'id_server', 'client_secret', | ||||
|             'country', 'phone_number', 'send_attempt', | ||||
|         ]) | ||||
| 
 | ||||
|         msisdn = phone_number_to_msisdn(body['country'], body['phone_number']) | ||||
| 
 | ||||
|         existingUid = yield self.datastore.get_user_id_by_threepid( | ||||
|             'msisdn', msisdn | ||||
|         ) | ||||
| 
 | ||||
|         if existingUid is None: | ||||
|             raise SynapseError(400, "MSISDN not found", Codes.THREEPID_NOT_FOUND) | ||||
| 
 | ||||
|         ret = yield self.identity_handler.requestMsisdnToken(**body) | ||||
|         defer.returnValue((200, ret)) | ||||
| 
 | ||||
| 
 | ||||
| class PasswordRestServlet(RestServlet): | ||||
|     PATTERNS = client_v2_patterns("/account/password$") | ||||
| 
 | ||||
| @ -68,6 +98,7 @@ class PasswordRestServlet(RestServlet): | ||||
|         self.hs = hs | ||||
|         self.auth = hs.get_auth() | ||||
|         self.auth_handler = hs.get_auth_handler() | ||||
|         self.datastore = self.hs.get_datastore() | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def on_POST(self, request): | ||||
| @ -77,7 +108,8 @@ class PasswordRestServlet(RestServlet): | ||||
| 
 | ||||
|         authed, result, params, _ = yield self.auth_handler.check_auth([ | ||||
|             [LoginType.PASSWORD], | ||||
|             [LoginType.EMAIL_IDENTITY] | ||||
|             [LoginType.EMAIL_IDENTITY], | ||||
|             [LoginType.MSISDN], | ||||
|         ], body, self.hs.get_ip_from_request(request)) | ||||
| 
 | ||||
|         if not authed: | ||||
| @ -102,7 +134,7 @@ class PasswordRestServlet(RestServlet): | ||||
|                 # (See add_threepid in synapse/handlers/auth.py) | ||||
|                 threepid['address'] = threepid['address'].lower() | ||||
|             # if using email, we must know about the email they're authing with! | ||||
|             threepid_user_id = yield self.hs.get_datastore().get_user_id_by_threepid( | ||||
|             threepid_user_id = yield self.datastore.get_user_id_by_threepid( | ||||
|                 threepid['medium'], threepid['address'] | ||||
|             ) | ||||
|             if not threepid_user_id: | ||||
| @ -169,13 +201,14 @@ class DeactivateAccountRestServlet(RestServlet): | ||||
|         defer.returnValue((200, {})) | ||||
| 
 | ||||
| 
 | ||||
| class ThreepidRequestTokenRestServlet(RestServlet): | ||||
| class EmailThreepidRequestTokenRestServlet(RestServlet): | ||||
|     PATTERNS = client_v2_patterns("/account/3pid/email/requestToken$") | ||||
| 
 | ||||
|     def __init__(self, hs): | ||||
|         self.hs = hs | ||||
|         super(ThreepidRequestTokenRestServlet, self).__init__() | ||||
|         super(EmailThreepidRequestTokenRestServlet, self).__init__() | ||||
|         self.identity_handler = hs.get_handlers().identity_handler | ||||
|         self.datastore = self.hs.get_datastore() | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def on_POST(self, request): | ||||
| @ -190,7 +223,7 @@ class ThreepidRequestTokenRestServlet(RestServlet): | ||||
|         if absent: | ||||
|             raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM) | ||||
| 
 | ||||
|         existingUid = yield self.hs.get_datastore().get_user_id_by_threepid( | ||||
|         existingUid = yield self.datastore.get_user_id_by_threepid( | ||||
|             'email', body['email'] | ||||
|         ) | ||||
| 
 | ||||
| @ -201,6 +234,44 @@ class ThreepidRequestTokenRestServlet(RestServlet): | ||||
|         defer.returnValue((200, ret)) | ||||
| 
 | ||||
| 
 | ||||
| class MsisdnThreepidRequestTokenRestServlet(RestServlet): | ||||
|     PATTERNS = client_v2_patterns("/account/3pid/msisdn/requestToken$") | ||||
| 
 | ||||
|     def __init__(self, hs): | ||||
|         self.hs = hs | ||||
|         super(MsisdnThreepidRequestTokenRestServlet, self).__init__() | ||||
|         self.identity_handler = hs.get_handlers().identity_handler | ||||
|         self.datastore = self.hs.get_datastore() | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def on_POST(self, request): | ||||
|         body = parse_json_object_from_request(request) | ||||
| 
 | ||||
|         required = [ | ||||
|             'id_server', 'client_secret', | ||||
|             'country', 'phone_number', 'send_attempt', | ||||
|         ] | ||||
|         absent = [] | ||||
|         for k in required: | ||||
|             if k not in body: | ||||
|                 absent.append(k) | ||||
| 
 | ||||
|         if absent: | ||||
|             raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM) | ||||
| 
 | ||||
|         msisdn = phone_number_to_msisdn(body['country'], body['phone_number']) | ||||
| 
 | ||||
|         existingUid = yield self.datastore.get_user_id_by_threepid( | ||||
|             'msisdn', msisdn | ||||
|         ) | ||||
| 
 | ||||
|         if existingUid is not None: | ||||
|             raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE) | ||||
| 
 | ||||
|         ret = yield self.identity_handler.requestMsisdnToken(**body) | ||||
|         defer.returnValue((200, ret)) | ||||
| 
 | ||||
| 
 | ||||
| class ThreepidRestServlet(RestServlet): | ||||
|     PATTERNS = client_v2_patterns("/account/3pid$") | ||||
| 
 | ||||
| @ -210,6 +281,7 @@ class ThreepidRestServlet(RestServlet): | ||||
|         self.identity_handler = hs.get_handlers().identity_handler | ||||
|         self.auth = hs.get_auth() | ||||
|         self.auth_handler = hs.get_auth_handler() | ||||
|         self.datastore = self.hs.get_datastore() | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def on_GET(self, request): | ||||
| @ -217,7 +289,7 @@ class ThreepidRestServlet(RestServlet): | ||||
| 
 | ||||
|         requester = yield self.auth.get_user_by_req(request) | ||||
| 
 | ||||
|         threepids = yield self.hs.get_datastore().user_get_threepids( | ||||
|         threepids = yield self.datastore.user_get_threepids( | ||||
|             requester.user.to_string() | ||||
|         ) | ||||
| 
 | ||||
| @ -258,7 +330,7 @@ class ThreepidRestServlet(RestServlet): | ||||
| 
 | ||||
|         if 'bind' in body and body['bind']: | ||||
|             logger.debug( | ||||
|                 "Binding emails %s to %s", | ||||
|                 "Binding threepid %s to %s", | ||||
|                 threepid, user_id | ||||
|             ) | ||||
|             yield self.identity_handler.bind_threepid( | ||||
| @ -302,9 +374,11 @@ class ThreepidDeleteRestServlet(RestServlet): | ||||
| 
 | ||||
| 
 | ||||
| def register_servlets(hs, http_server): | ||||
|     PasswordRequestTokenRestServlet(hs).register(http_server) | ||||
|     EmailPasswordRequestTokenRestServlet(hs).register(http_server) | ||||
|     MsisdnPasswordRequestTokenRestServlet(hs).register(http_server) | ||||
|     PasswordRestServlet(hs).register(http_server) | ||||
|     DeactivateAccountRestServlet(hs).register(http_server) | ||||
|     ThreepidRequestTokenRestServlet(hs).register(http_server) | ||||
|     EmailThreepidRequestTokenRestServlet(hs).register(http_server) | ||||
|     MsisdnThreepidRequestTokenRestServlet(hs).register(http_server) | ||||
|     ThreepidRestServlet(hs).register(http_server) | ||||
|     ThreepidDeleteRestServlet(hs).register(http_server) | ||||
|  | ||||
| @ -46,6 +46,52 @@ class DevicesRestServlet(servlet.RestServlet): | ||||
|         defer.returnValue((200, {"devices": devices})) | ||||
| 
 | ||||
| 
 | ||||
| class DeleteDevicesRestServlet(servlet.RestServlet): | ||||
|     """ | ||||
|     API for bulk deletion of devices. Accepts a JSON object with a devices | ||||
|     key which lists the device_ids to delete. Requires user interactive auth. | ||||
|     """ | ||||
|     PATTERNS = client_v2_patterns("/delete_devices", releases=[], v2_alpha=False) | ||||
| 
 | ||||
|     def __init__(self, hs): | ||||
|         super(DeleteDevicesRestServlet, self).__init__() | ||||
|         self.hs = hs | ||||
|         self.auth = hs.get_auth() | ||||
|         self.device_handler = hs.get_device_handler() | ||||
|         self.auth_handler = hs.get_auth_handler() | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def on_POST(self, request): | ||||
|         try: | ||||
|             body = servlet.parse_json_object_from_request(request) | ||||
|         except errors.SynapseError as e: | ||||
|             if e.errcode == errors.Codes.NOT_JSON: | ||||
|                 # deal with older clients which didn't pass a J*DELETESON dict | ||||
|                 # the same as those that pass an empty dict | ||||
|                 body = {} | ||||
|             else: | ||||
|                 raise e | ||||
| 
 | ||||
|         if 'devices' not in body: | ||||
|             raise errors.SynapseError( | ||||
|                 400, "No devices supplied", errcode=errors.Codes.MISSING_PARAM | ||||
|             ) | ||||
| 
 | ||||
|         authed, result, params, _ = yield self.auth_handler.check_auth([ | ||||
|             [constants.LoginType.PASSWORD], | ||||
|         ], body, self.hs.get_ip_from_request(request)) | ||||
| 
 | ||||
|         if not authed: | ||||
|             defer.returnValue((401, result)) | ||||
| 
 | ||||
|         requester = yield self.auth.get_user_by_req(request) | ||||
|         yield self.device_handler.delete_devices( | ||||
|             requester.user.to_string(), | ||||
|             body['devices'], | ||||
|         ) | ||||
|         defer.returnValue((200, {})) | ||||
| 
 | ||||
| 
 | ||||
| class DeviceRestServlet(servlet.RestServlet): | ||||
|     PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$", | ||||
|                                   releases=[], v2_alpha=False) | ||||
| @ -111,5 +157,6 @@ class DeviceRestServlet(servlet.RestServlet): | ||||
| 
 | ||||
| 
 | ||||
| def register_servlets(hs, http_server): | ||||
|     DeleteDevicesRestServlet(hs).register(http_server) | ||||
|     DevicesRestServlet(hs).register(http_server) | ||||
|     DeviceRestServlet(hs).register(http_server) | ||||
|  | ||||
| @ -1,5 +1,6 @@ | ||||
| # -*- coding: utf-8 -*- | ||||
| # Copyright 2015 - 2016 OpenMarket Ltd | ||||
| # Copyright 2017 Vector Creations Ltd | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| @ -19,7 +20,10 @@ import synapse | ||||
| from synapse.api.auth import get_access_token_from_request, has_access_token | ||||
| from synapse.api.constants import LoginType | ||||
| from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError | ||||
| from synapse.http.servlet import RestServlet, parse_json_object_from_request | ||||
| from synapse.http.servlet import ( | ||||
|     RestServlet, parse_json_object_from_request, assert_params_in_request | ||||
| ) | ||||
| from synapse.util.msisdn import phone_number_to_msisdn | ||||
| 
 | ||||
| from ._base import client_v2_patterns | ||||
| 
 | ||||
| @ -43,7 +47,7 @@ else: | ||||
| logger = logging.getLogger(__name__) | ||||
| 
 | ||||
| 
 | ||||
| class RegisterRequestTokenRestServlet(RestServlet): | ||||
| class EmailRegisterRequestTokenRestServlet(RestServlet): | ||||
|     PATTERNS = client_v2_patterns("/register/email/requestToken$") | ||||
| 
 | ||||
|     def __init__(self, hs): | ||||
| @ -51,7 +55,7 @@ class RegisterRequestTokenRestServlet(RestServlet): | ||||
|         Args: | ||||
|             hs (synapse.server.HomeServer): server | ||||
|         """ | ||||
|         super(RegisterRequestTokenRestServlet, self).__init__() | ||||
|         super(EmailRegisterRequestTokenRestServlet, self).__init__() | ||||
|         self.hs = hs | ||||
|         self.identity_handler = hs.get_handlers().identity_handler | ||||
| 
 | ||||
| @ -59,14 +63,9 @@ class RegisterRequestTokenRestServlet(RestServlet): | ||||
|     def on_POST(self, request): | ||||
|         body = parse_json_object_from_request(request) | ||||
| 
 | ||||
|         required = ['id_server', 'client_secret', 'email', 'send_attempt'] | ||||
|         absent = [] | ||||
|         for k in required: | ||||
|             if k not in body: | ||||
|                 absent.append(k) | ||||
| 
 | ||||
|         if len(absent) > 0: | ||||
|             raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM) | ||||
|         assert_params_in_request(body, [ | ||||
|             'id_server', 'client_secret', 'email', 'send_attempt' | ||||
|         ]) | ||||
| 
 | ||||
|         existingUid = yield self.hs.get_datastore().get_user_id_by_threepid( | ||||
|             'email', body['email'] | ||||
| @ -79,6 +78,43 @@ class RegisterRequestTokenRestServlet(RestServlet): | ||||
|         defer.returnValue((200, ret)) | ||||
| 
 | ||||
| 
 | ||||
| class MsisdnRegisterRequestTokenRestServlet(RestServlet): | ||||
|     PATTERNS = client_v2_patterns("/register/msisdn/requestToken$") | ||||
| 
 | ||||
|     def __init__(self, hs): | ||||
|         """ | ||||
|         Args: | ||||
|             hs (synapse.server.HomeServer): server | ||||
|         """ | ||||
|         super(MsisdnRegisterRequestTokenRestServlet, self).__init__() | ||||
|         self.hs = hs | ||||
|         self.identity_handler = hs.get_handlers().identity_handler | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def on_POST(self, request): | ||||
|         body = parse_json_object_from_request(request) | ||||
| 
 | ||||
|         assert_params_in_request(body, [ | ||||
|             'id_server', 'client_secret', | ||||
|             'country', 'phone_number', | ||||
|             'send_attempt', | ||||
|         ]) | ||||
| 
 | ||||
|         msisdn = phone_number_to_msisdn(body['country'], body['phone_number']) | ||||
| 
 | ||||
|         existingUid = yield self.hs.get_datastore().get_user_id_by_threepid( | ||||
|             'msisdn', msisdn | ||||
|         ) | ||||
| 
 | ||||
|         if existingUid is not None: | ||||
|             raise SynapseError( | ||||
|                 400, "Phone number is already in use", Codes.THREEPID_IN_USE | ||||
|             ) | ||||
| 
 | ||||
|         ret = yield self.identity_handler.requestMsisdnToken(**body) | ||||
|         defer.returnValue((200, ret)) | ||||
| 
 | ||||
| 
 | ||||
| class RegisterRestServlet(RestServlet): | ||||
|     PATTERNS = client_v2_patterns("/register$") | ||||
| 
 | ||||
| @ -200,16 +236,37 @@ class RegisterRestServlet(RestServlet): | ||||
|                 assigned_user_id=registered_user_id, | ||||
|             ) | ||||
| 
 | ||||
|         # Only give msisdn flows if the x_show_msisdn flag is given: | ||||
|         # this is a hack to work around the fact that clients were shipped | ||||
|         # that use fallback registration if they see any flows that they don't | ||||
|         # recognise, which means we break registration for these clients if we | ||||
|         # advertise msisdn flows. Once usage of Riot iOS <=0.3.9 and Riot | ||||
|         # Android <=0.6.9 have fallen below an acceptable threshold, this | ||||
|         # parameter should go away and we should always advertise msisdn flows. | ||||
|         show_msisdn = False | ||||
|         if 'x_show_msisdn' in body and body['x_show_msisdn']: | ||||
|             show_msisdn = True | ||||
| 
 | ||||
|         if self.hs.config.enable_registration_captcha: | ||||
|             flows = [ | ||||
|                 [LoginType.RECAPTCHA], | ||||
|                 [LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA] | ||||
|                 [LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA], | ||||
|             ] | ||||
|             if show_msisdn: | ||||
|                 flows.extend([ | ||||
|                     [LoginType.MSISDN, LoginType.RECAPTCHA], | ||||
|                     [LoginType.MSISDN, LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA], | ||||
|                 ]) | ||||
|         else: | ||||
|             flows = [ | ||||
|                 [LoginType.DUMMY], | ||||
|                 [LoginType.EMAIL_IDENTITY] | ||||
|                 [LoginType.EMAIL_IDENTITY], | ||||
|             ] | ||||
|             if show_msisdn: | ||||
|                 flows.extend([ | ||||
|                     [LoginType.MSISDN], | ||||
|                     [LoginType.MSISDN, LoginType.EMAIL_IDENTITY], | ||||
|                 ]) | ||||
| 
 | ||||
|         authed, auth_result, params, session_id = yield self.auth_handler.check_auth( | ||||
|             flows, body, self.hs.get_ip_from_request(request) | ||||
| @ -224,8 +281,9 @@ class RegisterRestServlet(RestServlet): | ||||
|                 "Already registered user ID %r for this session", | ||||
|                 registered_user_id | ||||
|             ) | ||||
|             # don't re-register the email address | ||||
|             # don't re-register the threepids | ||||
|             add_email = False | ||||
|             add_msisdn = False | ||||
|         else: | ||||
|             # NB: This may be from the auth handler and NOT from the POST | ||||
|             if 'password' not in params: | ||||
| @ -250,6 +308,7 @@ class RegisterRestServlet(RestServlet): | ||||
|             ) | ||||
| 
 | ||||
|             add_email = True | ||||
|             add_msisdn = True | ||||
| 
 | ||||
|         return_dict = yield self._create_registration_details( | ||||
|             registered_user_id, params | ||||
| @ -262,6 +321,13 @@ class RegisterRestServlet(RestServlet): | ||||
|                 params.get("bind_email") | ||||
|             ) | ||||
| 
 | ||||
|         if add_msisdn and auth_result and LoginType.MSISDN in auth_result: | ||||
|             threepid = auth_result[LoginType.MSISDN] | ||||
|             yield self._register_msisdn_threepid( | ||||
|                 registered_user_id, threepid, return_dict["access_token"], | ||||
|                 params.get("bind_msisdn") | ||||
|             ) | ||||
| 
 | ||||
|         defer.returnValue((200, return_dict)) | ||||
| 
 | ||||
|     def on_OPTIONS(self, _): | ||||
| @ -323,8 +389,9 @@ class RegisterRestServlet(RestServlet): | ||||
|         """ | ||||
|         reqd = ('medium', 'address', 'validated_at') | ||||
|         if any(x not in threepid for x in reqd): | ||||
|             # This will only happen if the ID server returns a malformed response | ||||
|             logger.info("Can't add incomplete 3pid") | ||||
|             defer.returnValue() | ||||
|             return | ||||
| 
 | ||||
|         yield self.auth_handler.add_threepid( | ||||
|             user_id, | ||||
| @ -371,6 +438,43 @@ class RegisterRestServlet(RestServlet): | ||||
|         else: | ||||
|             logger.info("bind_email not specified: not binding email") | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def _register_msisdn_threepid(self, user_id, threepid, token, bind_msisdn): | ||||
|         """Add a phone number as a 3pid identifier | ||||
| 
 | ||||
|         Also optionally binds msisdn to the given user_id on the identity server | ||||
| 
 | ||||
|         Args: | ||||
|             user_id (str): id of user | ||||
|             threepid (object): m.login.msisdn auth response | ||||
|             token (str): access_token for the user | ||||
|             bind_email (bool): true if the client requested the email to be | ||||
|                 bound at the identity server | ||||
|         Returns: | ||||
|             defer.Deferred: | ||||
|         """ | ||||
|         reqd = ('medium', 'address', 'validated_at') | ||||
|         if any(x not in threepid for x in reqd): | ||||
|             # This will only happen if the ID server returns a malformed response | ||||
|             logger.info("Can't add incomplete 3pid") | ||||
|             defer.returnValue() | ||||
| 
 | ||||
|         yield self.auth_handler.add_threepid( | ||||
|             user_id, | ||||
|             threepid['medium'], | ||||
|             threepid['address'], | ||||
|             threepid['validated_at'], | ||||
|         ) | ||||
| 
 | ||||
|         if bind_msisdn: | ||||
|             logger.info("bind_msisdn specified: binding") | ||||
|             logger.debug("Binding msisdn %s to %s", threepid, user_id) | ||||
|             yield self.identity_handler.bind_threepid( | ||||
|                 threepid['threepid_creds'], user_id | ||||
|             ) | ||||
|         else: | ||||
|             logger.info("bind_msisdn not specified: not binding msisdn") | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def _create_registration_details(self, user_id, params): | ||||
|         """Complete registration of newly-registered user | ||||
| @ -433,7 +537,7 @@ class RegisterRestServlet(RestServlet): | ||||
|         # we have nowhere to store it. | ||||
|         device_id = synapse.api.auth.GUEST_DEVICE_ID | ||||
|         initial_display_name = params.get("initial_device_display_name") | ||||
|         self.device_handler.check_device_registered( | ||||
|         yield self.device_handler.check_device_registered( | ||||
|             user_id, device_id, initial_display_name | ||||
|         ) | ||||
| 
 | ||||
| @ -449,5 +553,6 @@ class RegisterRestServlet(RestServlet): | ||||
| 
 | ||||
| 
 | ||||
| def register_servlets(hs, http_server): | ||||
|     RegisterRequestTokenRestServlet(hs).register(http_server) | ||||
|     EmailRegisterRequestTokenRestServlet(hs).register(http_server) | ||||
|     MsisdnRegisterRequestTokenRestServlet(hs).register(http_server) | ||||
|     RegisterRestServlet(hs).register(http_server) | ||||
|  | ||||
| @ -18,6 +18,7 @@ from twisted.internet import defer | ||||
| from synapse.http.servlet import ( | ||||
|     RestServlet, parse_string, parse_integer, parse_boolean | ||||
| ) | ||||
| from synapse.handlers.presence import format_user_presence_state | ||||
| from synapse.handlers.sync import SyncConfig | ||||
| from synapse.types import StreamToken | ||||
| from synapse.events.utils import ( | ||||
| @ -28,7 +29,6 @@ from synapse.api.errors import SynapseError | ||||
| from synapse.api.constants import PresenceState | ||||
| from ._base import client_v2_patterns | ||||
| 
 | ||||
| import copy | ||||
| import itertools | ||||
| import logging | ||||
| 
 | ||||
| @ -194,12 +194,18 @@ class SyncRestServlet(RestServlet): | ||||
|         defer.returnValue((200, response_content)) | ||||
| 
 | ||||
|     def encode_presence(self, events, time_now): | ||||
|         formatted = [] | ||||
|         for event in events: | ||||
|             event = copy.deepcopy(event) | ||||
|             event['sender'] = event['content'].pop('user_id') | ||||
|             formatted.append(event) | ||||
|         return {"events": formatted} | ||||
|         return { | ||||
|             "events": [ | ||||
|                 { | ||||
|                     "type": "m.presence", | ||||
|                     "sender": event.user_id, | ||||
|                     "content": format_user_presence_state( | ||||
|                         event, time_now, include_user_id=False | ||||
|                     ), | ||||
|                 } | ||||
|                 for event in events | ||||
|             ] | ||||
|         } | ||||
| 
 | ||||
|     def encode_joined(self, rooms, time_now, token_id, event_fields): | ||||
|         """ | ||||
|  | ||||
| @ -12,6 +12,7 @@ | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| import synapse.http.servlet | ||||
| 
 | ||||
| from ._base import parse_media_id, respond_with_file, respond_404 | ||||
| from twisted.web.resource import Resource | ||||
| @ -81,6 +82,17 @@ class DownloadResource(Resource): | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def _respond_remote_file(self, request, server_name, media_id, name): | ||||
|         # don't forward requests for remote media if allow_remote is false | ||||
|         allow_remote = synapse.http.servlet.parse_boolean( | ||||
|             request, "allow_remote", default=True) | ||||
|         if not allow_remote: | ||||
|             logger.info( | ||||
|                 "Rejecting request for remote media %s/%s due to allow_remote", | ||||
|                 server_name, media_id, | ||||
|             ) | ||||
|             respond_404(request) | ||||
|             return | ||||
| 
 | ||||
|         media_info = yield self.media_repo.get_remote_media(server_name, media_id) | ||||
| 
 | ||||
|         media_type = media_info["media_type"] | ||||
|  | ||||
| @ -13,22 +13,23 @@ | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| 
 | ||||
| from twisted.internet import defer, threads | ||||
| import twisted.internet.error | ||||
| import twisted.web.http | ||||
| from twisted.web.resource import Resource | ||||
| 
 | ||||
| from .upload_resource import UploadResource | ||||
| from .download_resource import DownloadResource | ||||
| from .thumbnail_resource import ThumbnailResource | ||||
| from .identicon_resource import IdenticonResource | ||||
| from .preview_url_resource import PreviewUrlResource | ||||
| from .filepath import MediaFilePaths | ||||
| 
 | ||||
| from twisted.web.resource import Resource | ||||
| 
 | ||||
| from .thumbnailer import Thumbnailer | ||||
| 
 | ||||
| from synapse.http.matrixfederationclient import MatrixFederationHttpClient | ||||
| from synapse.util.stringutils import random_string | ||||
| from synapse.api.errors import SynapseError | ||||
| 
 | ||||
| from twisted.internet import defer, threads | ||||
| from synapse.api.errors import SynapseError, HttpResponseException, \ | ||||
|     NotFoundError | ||||
| 
 | ||||
| from synapse.util.async import Linearizer | ||||
| from synapse.util.stringutils import is_ascii | ||||
| @ -157,11 +158,34 @@ class MediaRepository(object): | ||||
|                 try: | ||||
|                     length, headers = yield self.client.get_file( | ||||
|                         server_name, request_path, output_stream=f, | ||||
|                         max_size=self.max_upload_size, | ||||
|                         max_size=self.max_upload_size, args={ | ||||
|                             # tell the remote server to 404 if it doesn't | ||||
|                             # recognise the server_name, to make sure we don't | ||||
|                             # end up with a routing loop. | ||||
|                             "allow_remote": "false", | ||||
|                         } | ||||
|                     ) | ||||
|                 except Exception as e: | ||||
|                     logger.warn("Failed to fetch remoted media %r", e) | ||||
|                     raise SynapseError(502, "Failed to fetch remoted media") | ||||
|                 except twisted.internet.error.DNSLookupError as e: | ||||
|                     logger.warn("HTTP error fetching remote media %s/%s: %r", | ||||
|                                 server_name, media_id, e) | ||||
|                     raise NotFoundError() | ||||
| 
 | ||||
|                 except HttpResponseException as e: | ||||
|                     logger.warn("HTTP error fetching remote media %s/%s: %s", | ||||
|                                 server_name, media_id, e.response) | ||||
|                     if e.code == twisted.web.http.NOT_FOUND: | ||||
|                         raise SynapseError.from_http_response_exception(e) | ||||
|                     raise SynapseError(502, "Failed to fetch remote media") | ||||
| 
 | ||||
|                 except SynapseError: | ||||
|                     logger.exception("Failed to fetch remote media %s/%s", | ||||
|                                      server_name, media_id) | ||||
|                     raise | ||||
| 
 | ||||
|                 except Exception: | ||||
|                     logger.exception("Failed to fetch remote media %s/%s", | ||||
|                                      server_name, media_id) | ||||
|                     raise SynapseError(502, "Failed to fetch remote media") | ||||
| 
 | ||||
|             media_type = headers["Content-Type"][0] | ||||
|             time_now_ms = self.clock.time_msec() | ||||
|  | ||||
| @ -177,17 +177,12 @@ class StateHandler(object): | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def compute_event_context(self, event, old_state=None): | ||||
|         """ Fills out the context with the `current state` of the graph. The | ||||
|         `current state` here is defined to be the state of the event graph | ||||
|         just before the event - i.e. it never includes `event` | ||||
| 
 | ||||
|         If `event` has `auth_events` then this will also fill out the | ||||
|         `auth_events` field on `context` from the `current_state`. | ||||
|         """Build an EventContext structure for the event. | ||||
| 
 | ||||
|         Args: | ||||
|             event (EventBase) | ||||
|             event (synapse.events.EventBase): | ||||
|         Returns: | ||||
|             an EventContext | ||||
|             synapse.events.snapshot.EventContext: | ||||
|         """ | ||||
|         context = EventContext() | ||||
| 
 | ||||
| @ -200,11 +195,11 @@ class StateHandler(object): | ||||
|                     (s.type, s.state_key): s.event_id for s in old_state | ||||
|                 } | ||||
|                 if event.is_state(): | ||||
|                     context.current_state_events = dict(context.prev_state_ids) | ||||
|                     context.current_state_ids = dict(context.prev_state_ids) | ||||
|                     key = (event.type, event.state_key) | ||||
|                     context.current_state_events[key] = event.event_id | ||||
|                     context.current_state_ids[key] = event.event_id | ||||
|                 else: | ||||
|                     context.current_state_events = context.prev_state_ids | ||||
|                     context.current_state_ids = context.prev_state_ids | ||||
|             else: | ||||
|                 context.current_state_ids = {} | ||||
|                 context.prev_state_ids = {} | ||||
|  | ||||
| @ -73,6 +73,9 @@ class LoggingTransaction(object): | ||||
|     def __setattr__(self, name, value): | ||||
|         setattr(self.txn, name, value) | ||||
| 
 | ||||
|     def __iter__(self): | ||||
|         return self.txn.__iter__() | ||||
| 
 | ||||
|     def execute(self, sql, *args): | ||||
|         self._do_execute(self.txn.execute, sql, *args) | ||||
| 
 | ||||
| @ -132,7 +135,7 @@ class PerformanceCounters(object): | ||||
| 
 | ||||
|     def interval(self, interval_duration, limit=3): | ||||
|         counters = [] | ||||
|         for name, (count, cum_time) in self.current_counters.items(): | ||||
|         for name, (count, cum_time) in self.current_counters.iteritems(): | ||||
|             prev_count, prev_time = self.previous_counters.get(name, (0, 0)) | ||||
|             counters.append(( | ||||
|                 (cum_time - prev_time) / interval_duration, | ||||
| @ -357,7 +360,7 @@ class SQLBaseStore(object): | ||||
|         """ | ||||
|         col_headers = list(intern(column[0]) for column in cursor.description) | ||||
|         results = list( | ||||
|             dict(zip(col_headers, row)) for row in cursor.fetchall() | ||||
|             dict(zip(col_headers, row)) for row in cursor | ||||
|         ) | ||||
|         return results | ||||
| 
 | ||||
| @ -565,7 +568,7 @@ class SQLBaseStore(object): | ||||
|     @staticmethod | ||||
|     def _simple_select_onecol_txn(txn, table, keyvalues, retcol): | ||||
|         if keyvalues: | ||||
|             where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys()) | ||||
|             where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys()) | ||||
|         else: | ||||
|             where = "" | ||||
| 
 | ||||
| @ -579,7 +582,7 @@ class SQLBaseStore(object): | ||||
| 
 | ||||
|         txn.execute(sql, keyvalues.values()) | ||||
| 
 | ||||
|         return [r[0] for r in txn.fetchall()] | ||||
|         return [r[0] for r in txn] | ||||
| 
 | ||||
|     def _simple_select_onecol(self, table, keyvalues, retcol, | ||||
|                               desc="_simple_select_onecol"): | ||||
| @ -712,7 +715,7 @@ class SQLBaseStore(object): | ||||
|         ) | ||||
|         values.extend(iterable) | ||||
| 
 | ||||
|         for key, value in keyvalues.items(): | ||||
|         for key, value in keyvalues.iteritems(): | ||||
|             clauses.append("%s = ?" % (key,)) | ||||
|             values.append(value) | ||||
| 
 | ||||
| @ -753,7 +756,7 @@ class SQLBaseStore(object): | ||||
|     @staticmethod | ||||
|     def _simple_update_one_txn(txn, table, keyvalues, updatevalues): | ||||
|         if keyvalues: | ||||
|             where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys()) | ||||
|             where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys()) | ||||
|         else: | ||||
|             where = "" | ||||
| 
 | ||||
| @ -840,6 +843,47 @@ class SQLBaseStore(object): | ||||
| 
 | ||||
|         return txn.execute(sql, keyvalues.values()) | ||||
| 
 | ||||
|     def _simple_delete_many(self, table, column, iterable, keyvalues, desc): | ||||
|         return self.runInteraction( | ||||
|             desc, self._simple_delete_many_txn, table, column, iterable, keyvalues | ||||
|         ) | ||||
| 
 | ||||
|     @staticmethod | ||||
|     def _simple_delete_many_txn(txn, table, column, iterable, keyvalues): | ||||
|         """Executes a DELETE query on the named table. | ||||
| 
 | ||||
|         Filters rows by if value of `column` is in `iterable`. | ||||
| 
 | ||||
|         Args: | ||||
|             txn : Transaction object | ||||
|             table : string giving the table name | ||||
|             column : column name to test for inclusion against `iterable` | ||||
|             iterable : list | ||||
|             keyvalues : dict of column names and values to select the rows with | ||||
|         """ | ||||
|         if not iterable: | ||||
|             return | ||||
| 
 | ||||
|         sql = "DELETE FROM %s" % table | ||||
| 
 | ||||
|         clauses = [] | ||||
|         values = [] | ||||
|         clauses.append( | ||||
|             "%s IN (%s)" % (column, ",".join("?" for _ in iterable)) | ||||
|         ) | ||||
|         values.extend(iterable) | ||||
| 
 | ||||
|         for key, value in keyvalues.iteritems(): | ||||
|             clauses.append("%s = ?" % (key,)) | ||||
|             values.append(value) | ||||
| 
 | ||||
|         if clauses: | ||||
|             sql = "%s WHERE %s" % ( | ||||
|                 sql, | ||||
|                 " AND ".join(clauses), | ||||
|             ) | ||||
|         return txn.execute(sql, values) | ||||
| 
 | ||||
|     def _get_cache_dict(self, db_conn, table, entity_column, stream_column, | ||||
|                         max_value, limit=100000): | ||||
|         # Fetch a mapping of room_id -> max stream position for "recent" rooms. | ||||
| @ -860,16 +904,16 @@ class SQLBaseStore(object): | ||||
| 
 | ||||
|         txn = db_conn.cursor() | ||||
|         txn.execute(sql, (int(max_value),)) | ||||
|         rows = txn.fetchall() | ||||
|         txn.close() | ||||
| 
 | ||||
|         cache = { | ||||
|             row[0]: int(row[1]) | ||||
|             for row in rows | ||||
|             for row in txn | ||||
|         } | ||||
| 
 | ||||
|         txn.close() | ||||
| 
 | ||||
|         if cache: | ||||
|             min_val = min(cache.values()) | ||||
|             min_val = min(cache.itervalues()) | ||||
|         else: | ||||
|             min_val = max_value | ||||
| 
 | ||||
|  | ||||
| @ -182,7 +182,7 @@ class AccountDataStore(SQLBaseStore): | ||||
|             txn.execute(sql, (user_id, stream_id)) | ||||
| 
 | ||||
|             global_account_data = { | ||||
|                 row[0]: json.loads(row[1]) for row in txn.fetchall() | ||||
|                 row[0]: json.loads(row[1]) for row in txn | ||||
|             } | ||||
| 
 | ||||
|             sql = ( | ||||
| @ -193,7 +193,7 @@ class AccountDataStore(SQLBaseStore): | ||||
|             txn.execute(sql, (user_id, stream_id)) | ||||
| 
 | ||||
|             account_data_by_room = {} | ||||
|             for row in txn.fetchall(): | ||||
|             for row in txn: | ||||
|                 room_account_data = account_data_by_room.setdefault(row[0], {}) | ||||
|                 room_account_data[row[1]] = json.loads(row[2]) | ||||
| 
 | ||||
|  | ||||
| @ -12,6 +12,7 @@ | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| import synapse.util.async | ||||
| 
 | ||||
| from ._base import SQLBaseStore | ||||
| from . import engines | ||||
| @ -84,24 +85,14 @@ class BackgroundUpdateStore(SQLBaseStore): | ||||
|         self._background_update_performance = {} | ||||
|         self._background_update_queue = [] | ||||
|         self._background_update_handlers = {} | ||||
|         self._background_update_timer = None | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def start_doing_background_updates(self): | ||||
|         assert self._background_update_timer is None, \ | ||||
|             "background updates already running" | ||||
| 
 | ||||
|         logger.info("Starting background schema updates") | ||||
| 
 | ||||
|         while True: | ||||
|             sleep = defer.Deferred() | ||||
|             self._background_update_timer = self._clock.call_later( | ||||
|                 self.BACKGROUND_UPDATE_INTERVAL_MS / 1000., sleep.callback, None | ||||
|             ) | ||||
|             try: | ||||
|                 yield sleep | ||||
|             finally: | ||||
|                 self._background_update_timer = None | ||||
|             yield synapse.util.async.sleep( | ||||
|                 self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.) | ||||
| 
 | ||||
|             try: | ||||
|                 result = yield self.do_next_background_update( | ||||
|  | ||||
| @ -178,7 +178,7 @@ class DeviceInboxStore(BackgroundUpdateStore): | ||||
|                 ) | ||||
|                 txn.execute(sql, (user_id,)) | ||||
|                 message_json = ujson.dumps(messages_by_device["*"]) | ||||
|                 for row in txn.fetchall(): | ||||
|                 for row in txn: | ||||
|                     # Add the message for all devices for this user on this | ||||
|                     # server. | ||||
|                     device = row[0] | ||||
| @ -195,7 +195,7 @@ class DeviceInboxStore(BackgroundUpdateStore): | ||||
|                 # TODO: Maybe this needs to be done in batches if there are | ||||
|                 # too many local devices for a given user. | ||||
|                 txn.execute(sql, [user_id] + devices) | ||||
|                 for row in txn.fetchall(): | ||||
|                 for row in txn: | ||||
|                     # Only insert into the local inbox if the device exists on | ||||
|                     # this server | ||||
|                     device = row[0] | ||||
| @ -251,7 +251,7 @@ class DeviceInboxStore(BackgroundUpdateStore): | ||||
|                 user_id, device_id, last_stream_id, current_stream_id, limit | ||||
|             )) | ||||
|             messages = [] | ||||
|             for row in txn.fetchall(): | ||||
|             for row in txn: | ||||
|                 stream_pos = row[0] | ||||
|                 messages.append(ujson.loads(row[1])) | ||||
|             if len(messages) < limit: | ||||
| @ -340,7 +340,7 @@ class DeviceInboxStore(BackgroundUpdateStore): | ||||
|                 " ORDER BY stream_id ASC" | ||||
|             ) | ||||
|             txn.execute(sql, (last_pos, upper_pos)) | ||||
|             rows.extend(txn.fetchall()) | ||||
|             rows.extend(txn) | ||||
| 
 | ||||
|             return rows | ||||
| 
 | ||||
| @ -357,12 +357,12 @@ class DeviceInboxStore(BackgroundUpdateStore): | ||||
|         """ | ||||
|         Args: | ||||
|             destination(str): The name of the remote server. | ||||
|             last_stream_id(int): The last position of the device message stream | ||||
|             last_stream_id(int|long): The last position of the device message stream | ||||
|                 that the server sent up to. | ||||
|             current_stream_id(int): The current position of the device | ||||
|             current_stream_id(int|long): The current position of the device | ||||
|                 message stream. | ||||
|         Returns: | ||||
|             Deferred ([dict], int): List of messages for the device and where | ||||
|             Deferred ([dict], int|long): List of messages for the device and where | ||||
|                 in the stream the messages got to. | ||||
|         """ | ||||
| 
 | ||||
| @ -384,7 +384,7 @@ class DeviceInboxStore(BackgroundUpdateStore): | ||||
|                 destination, last_stream_id, current_stream_id, limit | ||||
|             )) | ||||
|             messages = [] | ||||
|             for row in txn.fetchall(): | ||||
|             for row in txn: | ||||
|                 stream_pos = row[0] | ||||
|                 messages.append(ujson.loads(row[1])) | ||||
|             if len(messages) < limit: | ||||
|  | ||||
| @ -108,6 +108,23 @@ class DeviceStore(SQLBaseStore): | ||||
|             desc="delete_device", | ||||
|         ) | ||||
| 
 | ||||
|     def delete_devices(self, user_id, device_ids): | ||||
|         """Deletes several devices. | ||||
| 
 | ||||
|         Args: | ||||
|             user_id (str): The ID of the user which owns the devices | ||||
|             device_ids (list): The IDs of the devices to delete | ||||
|         Returns: | ||||
|             defer.Deferred | ||||
|         """ | ||||
|         return self._simple_delete_many( | ||||
|             table="devices", | ||||
|             column="device_id", | ||||
|             iterable=device_ids, | ||||
|             keyvalues={"user_id": user_id}, | ||||
|             desc="delete_devices", | ||||
|         ) | ||||
| 
 | ||||
|     def update_device(self, user_id, device_id, new_display_name=None): | ||||
|         """Update a device. | ||||
| 
 | ||||
| @ -291,7 +308,7 @@ class DeviceStore(SQLBaseStore): | ||||
|         """Get stream of updates to send to remote servers | ||||
| 
 | ||||
|         Returns: | ||||
|             (now_stream_id, [ { updates }, .. ]) | ||||
|             (int, list[dict]): current stream id and list of updates | ||||
|         """ | ||||
|         now_stream_id = self._device_list_id_gen.get_current_token() | ||||
| 
 | ||||
| @ -312,17 +329,20 @@ class DeviceStore(SQLBaseStore): | ||||
|             SELECT user_id, device_id, max(stream_id) FROM device_lists_outbound_pokes | ||||
|             WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ? | ||||
|             GROUP BY user_id, device_id | ||||
|             LIMIT 20 | ||||
|         """ | ||||
|         txn.execute( | ||||
|             sql, (destination, from_stream_id, now_stream_id, False) | ||||
|         ) | ||||
|         rows = txn.fetchall() | ||||
| 
 | ||||
|         if not rows: | ||||
|             return (now_stream_id, []) | ||||
| 
 | ||||
|         # maps (user_id, device_id) -> stream_id | ||||
|         query_map = {(r[0], r[1]): r[2] for r in rows} | ||||
|         query_map = {(r[0], r[1]): r[2] for r in txn} | ||||
|         if not query_map: | ||||
|             return (now_stream_id, []) | ||||
| 
 | ||||
|         if len(query_map) >= 20: | ||||
|             now_stream_id = max(stream_id for stream_id in query_map.itervalues()) | ||||
| 
 | ||||
|         devices = self._get_e2e_device_keys_txn( | ||||
|             txn, query_map.keys(), include_all_devices=True | ||||
|         ) | ||||
|  | ||||
| @ -14,6 +14,8 @@ | ||||
| # limitations under the License. | ||||
| from twisted.internet import defer | ||||
| 
 | ||||
| from synapse.api.errors import SynapseError | ||||
| 
 | ||||
| from canonicaljson import encode_canonical_json | ||||
| import ujson as json | ||||
| 
 | ||||
| @ -120,24 +122,63 @@ class EndToEndKeyStore(SQLBaseStore): | ||||
| 
 | ||||
|         return result | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def add_e2e_one_time_keys(self, user_id, device_id, time_now, key_list): | ||||
|         """Insert some new one time keys for a device. | ||||
| 
 | ||||
|         Checks if any of the keys are already inserted, if they are then check | ||||
|         if they match. If they don't then we raise an error. | ||||
|         """ | ||||
| 
 | ||||
|         # First we check if we have already persisted any of the keys. | ||||
|         rows = yield self._simple_select_many_batch( | ||||
|             table="e2e_one_time_keys_json", | ||||
|             column="key_id", | ||||
|             iterable=[key_id for _, key_id, _ in key_list], | ||||
|             retcols=("algorithm", "key_id", "key_json",), | ||||
|             keyvalues={ | ||||
|                 "user_id": user_id, | ||||
|                 "device_id": device_id, | ||||
|             }, | ||||
|             desc="add_e2e_one_time_keys_check", | ||||
|         ) | ||||
| 
 | ||||
|         existing_key_map = { | ||||
|             (row["algorithm"], row["key_id"]): row["key_json"] for row in rows | ||||
|         } | ||||
| 
 | ||||
|         new_keys = []  # Keys that we need to insert | ||||
|         for algorithm, key_id, json_bytes in key_list: | ||||
|             ex_bytes = existing_key_map.get((algorithm, key_id), None) | ||||
|             if ex_bytes: | ||||
|                 if json_bytes != ex_bytes: | ||||
|                     raise SynapseError( | ||||
|                         400, "One time key with key_id %r already exists" % (key_id,) | ||||
|                     ) | ||||
|             else: | ||||
|                 new_keys.append((algorithm, key_id, json_bytes)) | ||||
| 
 | ||||
|         def _add_e2e_one_time_keys(txn): | ||||
|             for (algorithm, key_id, json_bytes) in key_list: | ||||
|                 self._simple_upsert_txn( | ||||
|                     txn, table="e2e_one_time_keys_json", | ||||
|                     keyvalues={ | ||||
|             # We are protected from race between lookup and insertion due to | ||||
|             # a unique constraint. If there is a race of two calls to | ||||
|             # `add_e2e_one_time_keys` then they'll conflict and we will only | ||||
|             # insert one set. | ||||
|             self._simple_insert_many_txn( | ||||
|                 txn, table="e2e_one_time_keys_json", | ||||
|                 values=[ | ||||
|                     { | ||||
|                         "user_id": user_id, | ||||
|                         "device_id": device_id, | ||||
|                         "algorithm": algorithm, | ||||
|                         "key_id": key_id, | ||||
|                     }, | ||||
|                     values={ | ||||
|                         "ts_added_ms": time_now, | ||||
|                         "key_json": json_bytes, | ||||
|                     } | ||||
|                 ) | ||||
|         return self.runInteraction( | ||||
|             "add_e2e_one_time_keys", _add_e2e_one_time_keys | ||||
|                     for algorithm, key_id, json_bytes in new_keys | ||||
|                 ], | ||||
|             ) | ||||
|         yield self.runInteraction( | ||||
|             "add_e2e_one_time_keys_insert", _add_e2e_one_time_keys | ||||
|         ) | ||||
| 
 | ||||
|     def count_e2e_one_time_keys(self, user_id, device_id): | ||||
| @ -153,7 +194,7 @@ class EndToEndKeyStore(SQLBaseStore): | ||||
|             ) | ||||
|             txn.execute(sql, (user_id, device_id)) | ||||
|             result = {} | ||||
|             for algorithm, key_count in txn.fetchall(): | ||||
|             for algorithm, key_count in txn: | ||||
|                 result[algorithm] = key_count | ||||
|             return result | ||||
|         return self.runInteraction( | ||||
| @ -174,7 +215,7 @@ class EndToEndKeyStore(SQLBaseStore): | ||||
|                 user_result = result.setdefault(user_id, {}) | ||||
|                 device_result = user_result.setdefault(device_id, {}) | ||||
|                 txn.execute(sql, (user_id, device_id, algorithm)) | ||||
|                 for key_id, key_json in txn.fetchall(): | ||||
|                 for key_id, key_json in txn: | ||||
|                     device_result[algorithm + ":" + key_id] = key_json | ||||
|                     delete.append((user_id, device_id, algorithm, key_id)) | ||||
|             sql = ( | ||||
|  | ||||
| @ -74,7 +74,7 @@ class EventFederationStore(SQLBaseStore): | ||||
|                     base_sql % (",".join(["?"] * len(chunk)),), | ||||
|                     chunk | ||||
|                 ) | ||||
|                 new_front.update([r[0] for r in txn.fetchall()]) | ||||
|                 new_front.update([r[0] for r in txn]) | ||||
| 
 | ||||
|             new_front -= results | ||||
| 
 | ||||
| @ -110,7 +110,7 @@ class EventFederationStore(SQLBaseStore): | ||||
| 
 | ||||
|         txn.execute(sql, (room_id, False,)) | ||||
| 
 | ||||
|         return dict(txn.fetchall()) | ||||
|         return dict(txn) | ||||
| 
 | ||||
|     def _get_oldest_events_in_room_txn(self, txn, room_id): | ||||
|         return self._simple_select_onecol_txn( | ||||
| @ -201,19 +201,19 @@ class EventFederationStore(SQLBaseStore): | ||||
|     def _update_min_depth_for_room_txn(self, txn, room_id, depth): | ||||
|         min_depth = self._get_min_depth_interaction(txn, room_id) | ||||
| 
 | ||||
|         do_insert = depth < min_depth if min_depth else True | ||||
|         if min_depth and depth >= min_depth: | ||||
|             return | ||||
| 
 | ||||
|         if do_insert: | ||||
|             self._simple_upsert_txn( | ||||
|                 txn, | ||||
|                 table="room_depth", | ||||
|                 keyvalues={ | ||||
|                     "room_id": room_id, | ||||
|                 }, | ||||
|                 values={ | ||||
|                     "min_depth": depth, | ||||
|                 }, | ||||
|             ) | ||||
|         self._simple_upsert_txn( | ||||
|             txn, | ||||
|             table="room_depth", | ||||
|             keyvalues={ | ||||
|                 "room_id": room_id, | ||||
|             }, | ||||
|             values={ | ||||
|                 "min_depth": depth, | ||||
|             }, | ||||
|         ) | ||||
| 
 | ||||
|     def _handle_mult_prev_events(self, txn, events): | ||||
|         """ | ||||
| @ -334,8 +334,7 @@ class EventFederationStore(SQLBaseStore): | ||||
| 
 | ||||
|         def get_forward_extremeties_for_room_txn(txn): | ||||
|             txn.execute(sql, (stream_ordering, room_id)) | ||||
|             rows = txn.fetchall() | ||||
|             return [event_id for event_id, in rows] | ||||
|             return [event_id for event_id, in txn] | ||||
| 
 | ||||
|         return self.runInteraction( | ||||
|             "get_forward_extremeties_for_room", | ||||
| @ -436,7 +435,7 @@ class EventFederationStore(SQLBaseStore): | ||||
|                 (room_id, event_id, False, limit - len(event_results)) | ||||
|             ) | ||||
| 
 | ||||
|             for row in txn.fetchall(): | ||||
|             for row in txn: | ||||
|                 if row[1] not in event_results: | ||||
|                     queue.put((-row[0], row[1])) | ||||
| 
 | ||||
| @ -482,7 +481,7 @@ class EventFederationStore(SQLBaseStore): | ||||
|                     (room_id, event_id, False, limit - len(event_results)) | ||||
|                 ) | ||||
| 
 | ||||
|                 for e_id, in txn.fetchall(): | ||||
|                 for e_id, in txn: | ||||
|                     new_front.add(e_id) | ||||
| 
 | ||||
|             new_front -= earliest_events | ||||
|  | ||||
| @ -206,7 +206,7 @@ class EventPushActionsStore(SQLBaseStore): | ||||
|                 " stream_ordering >= ? AND stream_ordering <= ?" | ||||
|             ) | ||||
|             txn.execute(sql, (min_stream_ordering, max_stream_ordering)) | ||||
|             return [r[0] for r in txn.fetchall()] | ||||
|             return [r[0] for r in txn] | ||||
|         ret = yield self.runInteraction("get_push_action_users_in_range", f) | ||||
|         defer.returnValue(ret) | ||||
| 
 | ||||
|  | ||||
| @ -34,14 +34,16 @@ from canonicaljson import encode_canonical_json | ||||
| from collections import deque, namedtuple, OrderedDict | ||||
| from functools import wraps | ||||
| 
 | ||||
| import synapse | ||||
| import synapse.metrics | ||||
| 
 | ||||
| 
 | ||||
| import logging | ||||
| import math | ||||
| import ujson as json | ||||
| 
 | ||||
| # these are only included to make the type annotations work | ||||
| from synapse.events import EventBase    # noqa: F401 | ||||
| from synapse.events.snapshot import EventContext   # noqa: F401 | ||||
| 
 | ||||
| logger = logging.getLogger(__name__) | ||||
| 
 | ||||
| 
 | ||||
| @ -82,6 +84,11 @@ class _EventPeristenceQueue(object): | ||||
| 
 | ||||
|     def add_to_queue(self, room_id, events_and_contexts, backfilled): | ||||
|         """Add events to the queue, with the given persist_event options. | ||||
| 
 | ||||
|         Args: | ||||
|             room_id (str): | ||||
|             events_and_contexts (list[(EventBase, EventContext)]): | ||||
|             backfilled (bool): | ||||
|         """ | ||||
|         queue = self._event_persist_queues.setdefault(room_id, deque()) | ||||
|         if queue: | ||||
| @ -210,14 +217,14 @@ class EventsStore(SQLBaseStore): | ||||
|             partitioned.setdefault(event.room_id, []).append((event, ctx)) | ||||
| 
 | ||||
|         deferreds = [] | ||||
|         for room_id, evs_ctxs in partitioned.items(): | ||||
|         for room_id, evs_ctxs in partitioned.iteritems(): | ||||
|             d = preserve_fn(self._event_persist_queue.add_to_queue)( | ||||
|                 room_id, evs_ctxs, | ||||
|                 backfilled=backfilled, | ||||
|             ) | ||||
|             deferreds.append(d) | ||||
| 
 | ||||
|         for room_id in partitioned.keys(): | ||||
|         for room_id in partitioned: | ||||
|             self._maybe_start_persisting(room_id) | ||||
| 
 | ||||
|         return preserve_context_over_deferred( | ||||
| @ -227,6 +234,17 @@ class EventsStore(SQLBaseStore): | ||||
|     @defer.inlineCallbacks | ||||
|     @log_function | ||||
|     def persist_event(self, event, context, backfilled=False): | ||||
|         """ | ||||
| 
 | ||||
|         Args: | ||||
|             event (EventBase): | ||||
|             context (EventContext): | ||||
|             backfilled (bool): | ||||
| 
 | ||||
|         Returns: | ||||
|             Deferred: resolves to (int, int): the stream ordering of ``event``, | ||||
|             and the stream ordering of the latest persisted event | ||||
|         """ | ||||
|         deferred = self._event_persist_queue.add_to_queue( | ||||
|             event.room_id, [(event, context)], | ||||
|             backfilled=backfilled, | ||||
| @ -253,6 +271,16 @@ class EventsStore(SQLBaseStore): | ||||
|     @defer.inlineCallbacks | ||||
|     def _persist_events(self, events_and_contexts, backfilled=False, | ||||
|                         delete_existing=False): | ||||
|         """Persist events to db | ||||
| 
 | ||||
|         Args: | ||||
|             events_and_contexts (list[(EventBase, EventContext)]): | ||||
|             backfilled (bool): | ||||
|             delete_existing (bool): | ||||
| 
 | ||||
|         Returns: | ||||
|             Deferred: resolves when the events have been persisted | ||||
|         """ | ||||
|         if not events_and_contexts: | ||||
|             return | ||||
| 
 | ||||
| @ -295,7 +323,7 @@ class EventsStore(SQLBaseStore): | ||||
|                                 (event, context) | ||||
|                             ) | ||||
| 
 | ||||
|                         for room_id, ev_ctx_rm in events_by_room.items(): | ||||
|                         for room_id, ev_ctx_rm in events_by_room.iteritems(): | ||||
|                             # Work out new extremities by recursively adding and removing | ||||
|                             # the new events. | ||||
|                             latest_event_ids = yield self.get_latest_event_ids_in_room( | ||||
| @ -400,6 +428,7 @@ class EventsStore(SQLBaseStore): | ||||
|         # Now we need to work out the different state sets for | ||||
|         # each state extremities | ||||
|         state_sets = [] | ||||
|         state_groups = set() | ||||
|         missing_event_ids = [] | ||||
|         was_updated = False | ||||
|         for event_id in new_latest_event_ids: | ||||
| @ -409,9 +438,17 @@ class EventsStore(SQLBaseStore): | ||||
|                 if event_id == ev.event_id: | ||||
|                     if ctx.current_state_ids is None: | ||||
|                         raise Exception("Unknown current state") | ||||
|                     state_sets.append(ctx.current_state_ids) | ||||
|                     if ctx.delta_ids or hasattr(ev, "state_key"): | ||||
|                         was_updated = True | ||||
| 
 | ||||
|                     # If we've already seen the state group don't bother adding | ||||
|                     # it to the state sets again | ||||
|                     if ctx.state_group not in state_groups: | ||||
|                         state_sets.append(ctx.current_state_ids) | ||||
|                         if ctx.delta_ids or hasattr(ev, "state_key"): | ||||
|                             was_updated = True | ||||
|                         if ctx.state_group: | ||||
|                             # Add this as a seen state group (if it has a state | ||||
|                             # group) | ||||
|                             state_groups.add(ctx.state_group) | ||||
|                     break | ||||
|             else: | ||||
|                 # If we couldn't find it, then we'll need to pull | ||||
| @ -425,31 +462,57 @@ class EventsStore(SQLBaseStore): | ||||
|                 missing_event_ids, | ||||
|             ) | ||||
| 
 | ||||
|             groups = set(event_to_groups.values()) | ||||
|             group_to_state = yield self._get_state_for_groups(groups) | ||||
|             groups = set(event_to_groups.itervalues()) - state_groups | ||||
| 
 | ||||
|             state_sets.extend(group_to_state.values()) | ||||
|             if groups: | ||||
|                 group_to_state = yield self._get_state_for_groups(groups) | ||||
|                 state_sets.extend(group_to_state.itervalues()) | ||||
| 
 | ||||
|         if not new_latest_event_ids: | ||||
|             current_state = {} | ||||
|         elif was_updated: | ||||
|             current_state = yield resolve_events( | ||||
|                 state_sets, | ||||
|                 state_map_factory=lambda ev_ids: self.get_events( | ||||
|                     ev_ids, get_prev_content=False, check_redacted=False, | ||||
|                 ), | ||||
|             ) | ||||
|             if len(state_sets) == 1: | ||||
|                 # If there is only one state set, then we know what the current | ||||
|                 # state is. | ||||
|                 current_state = state_sets[0] | ||||
|             else: | ||||
|                 # We work out the current state by passing the state sets to the | ||||
|                 # state resolution algorithm. It may ask for some events, including | ||||
|                 # the events we have yet to persist, so we need a slightly more | ||||
|                 # complicated event lookup function than simply looking the events | ||||
|                 # up in the db. | ||||
|                 events_map = {ev.event_id: ev for ev, _ in events_context} | ||||
| 
 | ||||
|                 @defer.inlineCallbacks | ||||
|                 def get_events(ev_ids): | ||||
|                     # We get the events by first looking at the list of events we | ||||
|                     # are trying to persist, and then fetching the rest from the DB. | ||||
|                     db = [] | ||||
|                     to_return = {} | ||||
|                     for ev_id in ev_ids: | ||||
|                         ev = events_map.get(ev_id, None) | ||||
|                         if ev: | ||||
|                             to_return[ev_id] = ev | ||||
|                         else: | ||||
|                             db.append(ev_id) | ||||
| 
 | ||||
|                     if db: | ||||
|                         evs = yield self.get_events( | ||||
|                             ev_ids, get_prev_content=False, check_redacted=False, | ||||
|                         ) | ||||
|                         to_return.update(evs) | ||||
|                     defer.returnValue(to_return) | ||||
| 
 | ||||
|                 current_state = yield resolve_events( | ||||
|                     state_sets, | ||||
|                     state_map_factory=get_events, | ||||
|                 ) | ||||
|         else: | ||||
|             return | ||||
| 
 | ||||
|         existing_state_rows = yield self._simple_select_list( | ||||
|             table="current_state_events", | ||||
|             keyvalues={"room_id": room_id}, | ||||
|             retcols=["event_id", "type", "state_key"], | ||||
|             desc="_calculate_state_delta", | ||||
|         ) | ||||
|         existing_state = yield self.get_current_state_ids(room_id) | ||||
| 
 | ||||
|         existing_events = set(row["event_id"] for row in existing_state_rows) | ||||
|         existing_events = set(existing_state.itervalues()) | ||||
|         new_events = set(ev_id for ev_id in current_state.itervalues()) | ||||
|         changed_events = existing_events ^ new_events | ||||
| 
 | ||||
| @ -457,9 +520,8 @@ class EventsStore(SQLBaseStore): | ||||
|             return | ||||
| 
 | ||||
|         to_delete = { | ||||
|             (row["type"], row["state_key"]): row["event_id"] | ||||
|             for row in existing_state_rows | ||||
|             if row["event_id"] in changed_events | ||||
|             key: ev_id for key, ev_id in existing_state.iteritems() | ||||
|             if ev_id in changed_events | ||||
|         } | ||||
|         events_to_insert = (new_events - existing_events) | ||||
|         to_insert = { | ||||
| @ -535,11 +597,91 @@ class EventsStore(SQLBaseStore): | ||||
|         and the rejections table. Things reading from those table will need to check | ||||
|         whether the event was rejected. | ||||
| 
 | ||||
|         If delete_existing is True then existing events will be purged from the | ||||
|         database before insertion. This is useful when retrying due to IntegrityError. | ||||
|         Args: | ||||
|             txn (twisted.enterprise.adbapi.Connection): db connection | ||||
|             events_and_contexts (list[(EventBase, EventContext)]): | ||||
|                 events to persist | ||||
|             backfilled (bool): True if the events were backfilled | ||||
|             delete_existing (bool): True to purge existing table rows for the | ||||
|                 events from the database. This is useful when retrying due to | ||||
|                 IntegrityError. | ||||
|             current_state_for_room (dict[str, (list[str], list[str])]): | ||||
|                 The current-state delta for each room. For each room, a tuple | ||||
|                 (to_delete, to_insert), being a list of event ids to be removed | ||||
|                 from the current state, and a list of event ids to be added to | ||||
|                 the current state. | ||||
|             new_forward_extremeties (dict[str, list[str]]): | ||||
|                 The new forward extremities for each room. For each room, a | ||||
|                 list of the event ids which are the forward extremities. | ||||
| 
 | ||||
|         """ | ||||
|         self._update_current_state_txn(txn, current_state_for_room) | ||||
| 
 | ||||
|         max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering | ||||
|         for room_id, current_state_tuple in current_state_for_room.iteritems(): | ||||
|         self._update_forward_extremities_txn( | ||||
|             txn, | ||||
|             new_forward_extremities=new_forward_extremeties, | ||||
|             max_stream_order=max_stream_order, | ||||
|         ) | ||||
| 
 | ||||
|         # Ensure that we don't have the same event twice. | ||||
|         events_and_contexts = self._filter_events_and_contexts_for_duplicates( | ||||
|             events_and_contexts, | ||||
|         ) | ||||
| 
 | ||||
|         self._update_room_depths_txn( | ||||
|             txn, | ||||
|             events_and_contexts=events_and_contexts, | ||||
|             backfilled=backfilled, | ||||
|         ) | ||||
| 
 | ||||
|         # _update_outliers_txn filters out any events which have already been | ||||
|         # persisted, and returns the filtered list. | ||||
|         events_and_contexts = self._update_outliers_txn( | ||||
|             txn, | ||||
|             events_and_contexts=events_and_contexts, | ||||
|         ) | ||||
| 
 | ||||
|         # From this point onwards the events are only events that we haven't | ||||
|         # seen before. | ||||
| 
 | ||||
|         if delete_existing: | ||||
|             # For paranoia reasons, we go and delete all the existing entries | ||||
|             # for these events so we can reinsert them. | ||||
|             # This gets around any problems with some tables already having | ||||
|             # entries. | ||||
|             self._delete_existing_rows_txn( | ||||
|                 txn, | ||||
|                 events_and_contexts=events_and_contexts, | ||||
|             ) | ||||
| 
 | ||||
|         self._store_event_txn( | ||||
|             txn, | ||||
|             events_and_contexts=events_and_contexts, | ||||
|         ) | ||||
| 
 | ||||
|         # Insert into the state_groups, state_groups_state, and | ||||
|         # event_to_state_groups tables. | ||||
|         self._store_mult_state_groups_txn(txn, events_and_contexts) | ||||
| 
 | ||||
|         # _store_rejected_events_txn filters out any events which were | ||||
|         # rejected, and returns the filtered list. | ||||
|         events_and_contexts = self._store_rejected_events_txn( | ||||
|             txn, | ||||
|             events_and_contexts=events_and_contexts, | ||||
|         ) | ||||
| 
 | ||||
|         # From this point onwards the events are only ones that weren't | ||||
|         # rejected. | ||||
| 
 | ||||
|         self._update_metadata_tables_txn( | ||||
|             txn, | ||||
|             events_and_contexts=events_and_contexts, | ||||
|             backfilled=backfilled, | ||||
|         ) | ||||
| 
 | ||||
|     def _update_current_state_txn(self, txn, state_delta_by_room): | ||||
|         for room_id, current_state_tuple in state_delta_by_room.iteritems(): | ||||
|                 to_delete, to_insert = current_state_tuple | ||||
|                 txn.executemany( | ||||
|                     "DELETE FROM current_state_events WHERE event_id = ?", | ||||
| @ -585,7 +727,13 @@ class EventsStore(SQLBaseStore): | ||||
|                     txn, self.get_users_in_room, (room_id,) | ||||
|                 ) | ||||
| 
 | ||||
|         for room_id, new_extrem in new_forward_extremeties.items(): | ||||
|                 self._invalidate_cache_and_stream( | ||||
|                     txn, self.get_current_state_ids, (room_id,) | ||||
|                 ) | ||||
| 
 | ||||
|     def _update_forward_extremities_txn(self, txn, new_forward_extremities, | ||||
|                                         max_stream_order): | ||||
|         for room_id, new_extrem in new_forward_extremities.iteritems(): | ||||
|             self._simple_delete_txn( | ||||
|                 txn, | ||||
|                 table="event_forward_extremities", | ||||
| @ -603,7 +751,7 @@ class EventsStore(SQLBaseStore): | ||||
|                     "event_id": ev_id, | ||||
|                     "room_id": room_id, | ||||
|                 } | ||||
|                 for room_id, new_extrem in new_forward_extremeties.items() | ||||
|                 for room_id, new_extrem in new_forward_extremities.iteritems() | ||||
|                 for ev_id in new_extrem | ||||
|             ], | ||||
|         ) | ||||
| @ -620,13 +768,22 @@ class EventsStore(SQLBaseStore): | ||||
|                     "event_id": event_id, | ||||
|                     "stream_ordering": max_stream_order, | ||||
|                 } | ||||
|                 for room_id, new_extrem in new_forward_extremeties.items() | ||||
|                 for room_id, new_extrem in new_forward_extremities.iteritems() | ||||
|                 for event_id in new_extrem | ||||
|             ] | ||||
|         ) | ||||
| 
 | ||||
|         # Ensure that we don't have the same event twice. | ||||
|         # Pick the earliest non-outlier if there is one, else the earliest one. | ||||
|     @classmethod | ||||
|     def _filter_events_and_contexts_for_duplicates(cls, events_and_contexts): | ||||
|         """Ensure that we don't have the same event twice. | ||||
| 
 | ||||
|         Pick the earliest non-outlier if there is one, else the earliest one. | ||||
| 
 | ||||
|         Args: | ||||
|             events_and_contexts (list[(EventBase, EventContext)]): | ||||
|         Returns: | ||||
|             list[(EventBase, EventContext)]: filtered list | ||||
|         """ | ||||
|         new_events_and_contexts = OrderedDict() | ||||
|         for event, context in events_and_contexts: | ||||
|             prev_event_context = new_events_and_contexts.get(event.event_id) | ||||
| @ -639,9 +796,17 @@ class EventsStore(SQLBaseStore): | ||||
|                         new_events_and_contexts[event.event_id] = (event, context) | ||||
|             else: | ||||
|                 new_events_and_contexts[event.event_id] = (event, context) | ||||
|         return new_events_and_contexts.values() | ||||
| 
 | ||||
|         events_and_contexts = new_events_and_contexts.values() | ||||
|     def _update_room_depths_txn(self, txn, events_and_contexts, backfilled): | ||||
|         """Update min_depth for each room | ||||
| 
 | ||||
|         Args: | ||||
|             txn (twisted.enterprise.adbapi.Connection): db connection | ||||
|             events_and_contexts (list[(EventBase, EventContext)]): events | ||||
|                 we are persisting | ||||
|             backfilled (bool): True if the events were backfilled | ||||
|         """ | ||||
|         depth_updates = {} | ||||
|         for event, context in events_and_contexts: | ||||
|             # Remove the any existing cache entries for the event_ids | ||||
| @ -657,9 +822,24 @@ class EventsStore(SQLBaseStore): | ||||
|                     event.depth, depth_updates.get(event.room_id, event.depth) | ||||
|                 ) | ||||
| 
 | ||||
|         for room_id, depth in depth_updates.items(): | ||||
|         for room_id, depth in depth_updates.iteritems(): | ||||
|             self._update_min_depth_for_room_txn(txn, room_id, depth) | ||||
| 
 | ||||
|     def _update_outliers_txn(self, txn, events_and_contexts): | ||||
|         """Update any outliers with new event info. | ||||
| 
 | ||||
|         This turns outliers into ex-outliers (unless the new event was | ||||
|         rejected). | ||||
| 
 | ||||
|         Args: | ||||
|             txn (twisted.enterprise.adbapi.Connection): db connection | ||||
|             events_and_contexts (list[(EventBase, EventContext)]): events | ||||
|                 we are persisting | ||||
| 
 | ||||
|         Returns: | ||||
|             list[(EventBase, EventContext)] new list, without events which | ||||
|             are already in the events table. | ||||
|         """ | ||||
|         txn.execute( | ||||
|             "SELECT event_id, outlier FROM events WHERE event_id in (%s)" % ( | ||||
|                 ",".join(["?"] * len(events_and_contexts)), | ||||
| @ -669,24 +849,21 @@ class EventsStore(SQLBaseStore): | ||||
| 
 | ||||
|         have_persisted = { | ||||
|             event_id: outlier | ||||
|             for event_id, outlier in txn.fetchall() | ||||
|             for event_id, outlier in txn | ||||
|         } | ||||
| 
 | ||||
|         to_remove = set() | ||||
|         for event, context in events_and_contexts: | ||||
|             if context.rejected: | ||||
|                 # If the event is rejected then we don't care if the event | ||||
|                 # was an outlier or not. | ||||
|                 if event.event_id in have_persisted: | ||||
|                     # If we have already seen the event then ignore it. | ||||
|                     to_remove.add(event) | ||||
|                 continue | ||||
| 
 | ||||
|             if event.event_id not in have_persisted: | ||||
|                 continue | ||||
| 
 | ||||
|             to_remove.add(event) | ||||
| 
 | ||||
|             if context.rejected: | ||||
|                 # If the event is rejected then we don't care if the event | ||||
|                 # was an outlier or not. | ||||
|                 continue | ||||
| 
 | ||||
|             outlier_persisted = have_persisted[event.event_id] | ||||
|             if not event.internal_metadata.is_outlier() and outlier_persisted: | ||||
|                 # We received a copy of an event that we had already stored as | ||||
| @ -741,37 +918,19 @@ class EventsStore(SQLBaseStore): | ||||
|                 # event isn't an outlier any more. | ||||
|                 self._update_backward_extremeties(txn, [event]) | ||||
| 
 | ||||
|         events_and_contexts = [ | ||||
|         return [ | ||||
|             ec for ec in events_and_contexts if ec[0] not in to_remove | ||||
|         ] | ||||
| 
 | ||||
|     @classmethod | ||||
|     def _delete_existing_rows_txn(cls, txn, events_and_contexts): | ||||
|         if not events_and_contexts: | ||||
|             # Make sure we don't pass an empty list to functions that expect to | ||||
|             # be storing at least one element. | ||||
|             # nothing to do here | ||||
|             return | ||||
| 
 | ||||
|         # From this point onwards the events are only events that we haven't | ||||
|         # seen before. | ||||
|         logger.info("Deleting existing") | ||||
| 
 | ||||
|         def event_dict(event): | ||||
|             return { | ||||
|                 k: v | ||||
|                 for k, v in event.get_dict().items() | ||||
|                 if k not in [ | ||||
|                     "redacted", | ||||
|                     "redacted_because", | ||||
|                 ] | ||||
|             } | ||||
| 
 | ||||
|         if delete_existing: | ||||
|             # For paranoia reasons, we go and delete all the existing entries | ||||
|             # for these events so we can reinsert them. | ||||
|             # This gets around any problems with some tables already having | ||||
|             # entries. | ||||
| 
 | ||||
|             logger.info("Deleting existing") | ||||
| 
 | ||||
|             for table in ( | ||||
|         for table in ( | ||||
|                 "events", | ||||
|                 "event_auth", | ||||
|                 "event_json", | ||||
| @ -794,11 +953,30 @@ class EventsStore(SQLBaseStore): | ||||
|                 "redactions", | ||||
|                 "room_memberships", | ||||
|                 "topics" | ||||
|             ): | ||||
|                 txn.executemany( | ||||
|                     "DELETE FROM %s WHERE event_id = ?" % (table,), | ||||
|                     [(ev.event_id,) for ev, _ in events_and_contexts] | ||||
|                 ) | ||||
|         ): | ||||
|             txn.executemany( | ||||
|                 "DELETE FROM %s WHERE event_id = ?" % (table,), | ||||
|                 [(ev.event_id,) for ev, _ in events_and_contexts] | ||||
|             ) | ||||
| 
 | ||||
|     def _store_event_txn(self, txn, events_and_contexts): | ||||
|         """Insert new events into the event and event_json tables | ||||
| 
 | ||||
|         Args: | ||||
|             txn (twisted.enterprise.adbapi.Connection): db connection | ||||
|             events_and_contexts (list[(EventBase, EventContext)]): events | ||||
|                 we are persisting | ||||
|         """ | ||||
| 
 | ||||
|         if not events_and_contexts: | ||||
|             # nothing to do here | ||||
|             return | ||||
| 
 | ||||
|         def event_dict(event): | ||||
|             d = event.get_dict() | ||||
|             d.pop("redacted", None) | ||||
|             d.pop("redacted_because", None) | ||||
|             return d | ||||
| 
 | ||||
|         self._simple_insert_many_txn( | ||||
|             txn, | ||||
| @ -842,6 +1020,19 @@ class EventsStore(SQLBaseStore): | ||||
|             ], | ||||
|         ) | ||||
| 
 | ||||
|     def _store_rejected_events_txn(self, txn, events_and_contexts): | ||||
|         """Add rows to the 'rejections' table for received events which were | ||||
|         rejected | ||||
| 
 | ||||
|         Args: | ||||
|             txn (twisted.enterprise.adbapi.Connection): db connection | ||||
|             events_and_contexts (list[(EventBase, EventContext)]): events | ||||
|                 we are persisting | ||||
| 
 | ||||
|         Returns: | ||||
|             list[(EventBase, EventContext)] new list, without the rejected | ||||
|                 events. | ||||
|         """ | ||||
|         # Remove the rejected events from the list now that we've added them | ||||
|         # to the events table and the events_json table. | ||||
|         to_remove = set() | ||||
| @ -853,16 +1044,23 @@ class EventsStore(SQLBaseStore): | ||||
|                 ) | ||||
|                 to_remove.add(event) | ||||
| 
 | ||||
|         events_and_contexts = [ | ||||
|         return [ | ||||
|             ec for ec in events_and_contexts if ec[0] not in to_remove | ||||
|         ] | ||||
| 
 | ||||
|         if not events_and_contexts: | ||||
|             # Make sure we don't pass an empty list to functions that expect to | ||||
|             # be storing at least one element. | ||||
|             return | ||||
|     def _update_metadata_tables_txn(self, txn, events_and_contexts, backfilled): | ||||
|         """Update all the miscellaneous tables for new events | ||||
| 
 | ||||
|         # From this point onwards the events are only ones that weren't rejected. | ||||
|         Args: | ||||
|             txn (twisted.enterprise.adbapi.Connection): db connection | ||||
|             events_and_contexts (list[(EventBase, EventContext)]): events | ||||
|                 we are persisting | ||||
|             backfilled (bool): True if the events were backfilled | ||||
|         """ | ||||
| 
 | ||||
|         if not events_and_contexts: | ||||
|             # nothing to do here | ||||
|             return | ||||
| 
 | ||||
|         for event, context in events_and_contexts: | ||||
|             # Insert all the push actions into the event_push_actions table. | ||||
| @ -892,10 +1090,6 @@ class EventsStore(SQLBaseStore): | ||||
|             ], | ||||
|         ) | ||||
| 
 | ||||
|         # Insert into the state_groups, state_groups_state, and | ||||
|         # event_to_state_groups tables. | ||||
|         self._store_mult_state_groups_txn(txn, events_and_contexts) | ||||
| 
 | ||||
|         # Update the event_forward_extremities, event_backward_extremities and | ||||
|         # event_edges tables. | ||||
|         self._handle_mult_prev_events( | ||||
| @ -982,13 +1176,6 @@ class EventsStore(SQLBaseStore): | ||||
|         # Prefill the event cache | ||||
|         self._add_to_cache(txn, events_and_contexts) | ||||
| 
 | ||||
|         if backfilled: | ||||
|             # Backfilled events come before the current state so we don't need | ||||
|             # to update the current state table | ||||
|             return | ||||
| 
 | ||||
|         return | ||||
| 
 | ||||
|     def _add_to_cache(self, txn, events_and_contexts): | ||||
|         to_prefill = [] | ||||
| 
 | ||||
| @ -1597,14 +1784,13 @@ class EventsStore(SQLBaseStore): | ||||
| 
 | ||||
|         def get_all_new_events_txn(txn): | ||||
|             sql = ( | ||||
|                 "SELECT e.stream_ordering, ej.internal_metadata, ej.json, eg.state_group" | ||||
|                 " FROM events as e" | ||||
|                 " JOIN event_json as ej" | ||||
|                 " ON e.event_id = ej.event_id AND e.room_id = ej.room_id" | ||||
|                 " LEFT JOIN event_to_state_groups as eg" | ||||
|                 " ON e.event_id = eg.event_id" | ||||
|                 " WHERE ? < e.stream_ordering AND e.stream_ordering <= ?" | ||||
|                 " ORDER BY e.stream_ordering ASC" | ||||
|                 "SELECT e.stream_ordering, e.event_id, e.room_id, e.type," | ||||
|                 " state_key, redacts" | ||||
|                 " FROM events AS e" | ||||
|                 " LEFT JOIN redactions USING (event_id)" | ||||
|                 " LEFT JOIN state_events USING (event_id)" | ||||
|                 " WHERE ? < stream_ordering AND stream_ordering <= ?" | ||||
|                 " ORDER BY stream_ordering ASC" | ||||
|                 " LIMIT ?" | ||||
|             ) | ||||
|             if have_forward_events: | ||||
| @ -1630,15 +1816,13 @@ class EventsStore(SQLBaseStore): | ||||
|                 forward_ex_outliers = [] | ||||
| 
 | ||||
|             sql = ( | ||||
|                 "SELECT -e.stream_ordering, ej.internal_metadata, ej.json," | ||||
|                 " eg.state_group" | ||||
|                 " FROM events as e" | ||||
|                 " JOIN event_json as ej" | ||||
|                 " ON e.event_id = ej.event_id AND e.room_id = ej.room_id" | ||||
|                 " LEFT JOIN event_to_state_groups as eg" | ||||
|                 " ON e.event_id = eg.event_id" | ||||
|                 " WHERE ? > e.stream_ordering AND e.stream_ordering >= ?" | ||||
|                 " ORDER BY e.stream_ordering DESC" | ||||
|                 "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type," | ||||
|                 " state_key, redacts" | ||||
|                 " FROM events AS e" | ||||
|                 " LEFT JOIN redactions USING (event_id)" | ||||
|                 " LEFT JOIN state_events USING (event_id)" | ||||
|                 " WHERE ? > stream_ordering AND stream_ordering >= ?" | ||||
|                 " ORDER BY stream_ordering DESC" | ||||
|                 " LIMIT ?" | ||||
|             ) | ||||
|             if have_backfill_events: | ||||
| @ -1825,7 +2009,7 @@ class EventsStore(SQLBaseStore): | ||||
|                         "state_key": key[1], | ||||
|                         "event_id": state_id, | ||||
|                     } | ||||
|                     for key, state_id in curr_state.items() | ||||
|                     for key, state_id in curr_state.iteritems() | ||||
|                 ], | ||||
|             ) | ||||
| 
 | ||||
|  | ||||
| @ -101,9 +101,10 @@ class KeyStore(SQLBaseStore): | ||||
|         key_ids | ||||
|         Args: | ||||
|             server_name (str): The name of the server. | ||||
|             key_ids (list of str): List of key_ids to try and look up. | ||||
|             key_ids (iterable[str]): key_ids to try and look up. | ||||
|         Returns: | ||||
|             (list of VerifyKey): The verification keys. | ||||
|             Deferred: resolves to dict[str, VerifyKey]: map from | ||||
|                key_id to verification key. | ||||
|         """ | ||||
|         keys = {} | ||||
|         for key_id in key_ids: | ||||
|  | ||||
| @ -356,7 +356,7 @@ def _get_or_create_schema_state(txn, database_engine): | ||||
|             ), | ||||
|             (current_version,) | ||||
|         ) | ||||
|         applied_deltas = [d for d, in txn.fetchall()] | ||||
|         applied_deltas = [d for d, in txn] | ||||
|         return current_version, applied_deltas, upgraded | ||||
| 
 | ||||
|     return None | ||||
|  | ||||
| @ -85,8 +85,8 @@ class PresenceStore(SQLBaseStore): | ||||
|                 self.presence_stream_cache.entity_has_changed, | ||||
|                 state.user_id, stream_id, | ||||
|             ) | ||||
|             self._invalidate_cache_and_stream( | ||||
|                 txn, self._get_presence_for_user, (state.user_id,) | ||||
|             txn.call_after( | ||||
|                 self._get_presence_for_user.invalidate, (state.user_id,) | ||||
|             ) | ||||
| 
 | ||||
|         # Actually insert new rows | ||||
|  | ||||
| @ -313,10 +313,9 @@ class ReceiptsStore(SQLBaseStore): | ||||
|         ) | ||||
| 
 | ||||
|         txn.execute(sql, (room_id, receipt_type, user_id)) | ||||
|         results = txn.fetchall() | ||||
| 
 | ||||
|         if results and topological_ordering: | ||||
|             for to, so, _ in results: | ||||
|         if topological_ordering: | ||||
|             for to, so, _ in txn: | ||||
|                 if int(to) > topological_ordering: | ||||
|                     return False | ||||
|                 elif int(to) == topological_ordering and int(so) >= stream_ordering: | ||||
|  | ||||
| @ -209,7 +209,7 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): | ||||
|                 " WHERE lower(name) = lower(?)" | ||||
|             ) | ||||
|             txn.execute(sql, (user_id,)) | ||||
|             return dict(txn.fetchall()) | ||||
|             return dict(txn) | ||||
| 
 | ||||
|         return self.runInteraction("get_users_by_id_case_insensitive", f) | ||||
| 
 | ||||
|  | ||||
| @ -396,7 +396,7 @@ class RoomStore(SQLBaseStore): | ||||
|                     sql % ("AND appservice_id IS NULL",), | ||||
|                     (stream_id,) | ||||
|                 ) | ||||
|             return dict(txn.fetchall()) | ||||
|             return dict(txn) | ||||
|         else: | ||||
|             # We want to get from all lists, so we need to aggregate the results | ||||
| 
 | ||||
| @ -422,7 +422,7 @@ class RoomStore(SQLBaseStore): | ||||
| 
 | ||||
|             results = {} | ||||
|             # A room is visible if its visible on any list. | ||||
|             for room_id, visibility in txn.fetchall(): | ||||
|             for room_id, visibility in txn: | ||||
|                 results[room_id] = bool(visibility) or results.get(room_id, False) | ||||
| 
 | ||||
|             return results | ||||
|  | ||||
| @ -129,17 +129,30 @@ class RoomMemberStore(SQLBaseStore): | ||||
|         with self._stream_id_gen.get_next() as stream_ordering: | ||||
|             yield self.runInteraction("locally_reject_invite", f, stream_ordering) | ||||
| 
 | ||||
|     @cachedInlineCallbacks(max_entries=100000, iterable=True, cache_context=True) | ||||
|     def get_hosts_in_room(self, room_id, cache_context): | ||||
|         """Returns the set of all hosts currently in the room | ||||
|         """ | ||||
|         user_ids = yield self.get_users_in_room( | ||||
|             room_id, on_invalidate=cache_context.invalidate, | ||||
|         ) | ||||
|         hosts = frozenset(get_domain_from_id(user_id) for user_id in user_ids) | ||||
|         defer.returnValue(hosts) | ||||
| 
 | ||||
|     @cached(max_entries=500000, iterable=True) | ||||
|     def get_users_in_room(self, room_id): | ||||
|         def f(txn): | ||||
| 
 | ||||
|             rows = self._get_members_rows_txn( | ||||
|                 txn, | ||||
|                 room_id=room_id, | ||||
|                 membership=Membership.JOIN, | ||||
|             sql = ( | ||||
|                 "SELECT m.user_id FROM room_memberships as m" | ||||
|                 " INNER JOIN current_state_events as c" | ||||
|                 " ON m.event_id = c.event_id " | ||||
|                 " AND m.room_id = c.room_id " | ||||
|                 " AND m.user_id = c.state_key" | ||||
|                 " WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ?" | ||||
|             ) | ||||
| 
 | ||||
|             return [r["user_id"] for r in rows] | ||||
|             txn.execute(sql, (room_id, Membership.JOIN,)) | ||||
|             return [r[0] for r in txn] | ||||
|         return self.runInteraction("get_users_in_room", f) | ||||
| 
 | ||||
|     @cached() | ||||
| @ -246,52 +259,27 @@ class RoomMemberStore(SQLBaseStore): | ||||
| 
 | ||||
|         return results | ||||
| 
 | ||||
|     def _get_members_rows_txn(self, txn, room_id, membership=None, user_id=None): | ||||
|         where_clause = "c.room_id = ?" | ||||
|         where_values = [room_id] | ||||
| 
 | ||||
|         if membership: | ||||
|             where_clause += " AND m.membership = ?" | ||||
|             where_values.append(membership) | ||||
| 
 | ||||
|         if user_id: | ||||
|             where_clause += " AND m.user_id = ?" | ||||
|             where_values.append(user_id) | ||||
| 
 | ||||
|         sql = ( | ||||
|             "SELECT m.* FROM room_memberships as m" | ||||
|             " INNER JOIN current_state_events as c" | ||||
|             " ON m.event_id = c.event_id " | ||||
|             " AND m.room_id = c.room_id " | ||||
|             " AND m.user_id = c.state_key" | ||||
|             " WHERE c.type = 'm.room.member' AND %(where)s" | ||||
|         ) % { | ||||
|             "where": where_clause, | ||||
|         } | ||||
| 
 | ||||
|         txn.execute(sql, where_values) | ||||
|         rows = self.cursor_to_dict(txn) | ||||
| 
 | ||||
|         return rows | ||||
| 
 | ||||
|     @cached(max_entries=500000, iterable=True) | ||||
|     @cachedInlineCallbacks(max_entries=500000, iterable=True) | ||||
|     def get_rooms_for_user(self, user_id): | ||||
|         return self.get_rooms_for_user_where_membership_is( | ||||
|         """Returns a set of room_ids the user is currently joined to | ||||
|         """ | ||||
|         rooms = yield self.get_rooms_for_user_where_membership_is( | ||||
|             user_id, membership_list=[Membership.JOIN], | ||||
|         ) | ||||
|         defer.returnValue(frozenset(r.room_id for r in rooms)) | ||||
| 
 | ||||
|     @cachedInlineCallbacks(max_entries=500000, cache_context=True, iterable=True) | ||||
|     def get_users_who_share_room_with_user(self, user_id, cache_context): | ||||
|         """Returns the set of users who share a room with `user_id` | ||||
|         """ | ||||
|         rooms = yield self.get_rooms_for_user( | ||||
|         room_ids = yield self.get_rooms_for_user( | ||||
|             user_id, on_invalidate=cache_context.invalidate, | ||||
|         ) | ||||
| 
 | ||||
|         user_who_share_room = set() | ||||
|         for room in rooms: | ||||
|         for room_id in room_ids: | ||||
|             user_ids = yield self.get_users_in_room( | ||||
|                 room.room_id, on_invalidate=cache_context.invalidate, | ||||
|                 room_id, on_invalidate=cache_context.invalidate, | ||||
|             ) | ||||
|             user_who_share_room.update(user_ids) | ||||
| 
 | ||||
|  | ||||
| @ -72,7 +72,7 @@ class SignatureStore(SQLBaseStore): | ||||
|             " WHERE event_id = ?" | ||||
|         ) | ||||
|         txn.execute(query, (event_id, )) | ||||
|         return {k: v for k, v in txn.fetchall()} | ||||
|         return {k: v for k, v in txn} | ||||
| 
 | ||||
|     def _store_event_reference_hashes_txn(self, txn, events): | ||||
|         """Store a hash for a PDU | ||||
|  | ||||
| @ -14,7 +14,7 @@ | ||||
| # limitations under the License. | ||||
| 
 | ||||
| from ._base import SQLBaseStore | ||||
| from synapse.util.caches.descriptors import cached, cachedList | ||||
| from synapse.util.caches.descriptors import cached, cachedList, cachedInlineCallbacks | ||||
| from synapse.util.caches import intern_string | ||||
| from synapse.storage.engines import PostgresEngine | ||||
| 
 | ||||
| @ -69,6 +69,18 @@ class StateStore(SQLBaseStore): | ||||
|             where_clause="type='m.room.member'", | ||||
|         ) | ||||
| 
 | ||||
|     @cachedInlineCallbacks(max_entries=100000, iterable=True) | ||||
|     def get_current_state_ids(self, room_id): | ||||
|         rows = yield self._simple_select_list( | ||||
|             table="current_state_events", | ||||
|             keyvalues={"room_id": room_id}, | ||||
|             retcols=["event_id", "type", "state_key"], | ||||
|             desc="_calculate_state_delta", | ||||
|         ) | ||||
|         defer.returnValue({ | ||||
|             (r["type"], r["state_key"]): r["event_id"] for r in rows | ||||
|         }) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def get_state_groups_ids(self, room_id, event_ids): | ||||
|         if not event_ids: | ||||
| @ -78,7 +90,7 @@ class StateStore(SQLBaseStore): | ||||
|             event_ids, | ||||
|         ) | ||||
| 
 | ||||
|         groups = set(event_to_groups.values()) | ||||
|         groups = set(event_to_groups.itervalues()) | ||||
|         group_to_state = yield self._get_state_for_groups(groups) | ||||
| 
 | ||||
|         defer.returnValue(group_to_state) | ||||
| @ -96,17 +108,18 @@ class StateStore(SQLBaseStore): | ||||
| 
 | ||||
|         state_event_map = yield self.get_events( | ||||
|             [ | ||||
|                 ev_id for group_ids in group_to_ids.values() | ||||
|                 for ev_id in group_ids.values() | ||||
|                 ev_id for group_ids in group_to_ids.itervalues() | ||||
|                 for ev_id in group_ids.itervalues() | ||||
|             ], | ||||
|             get_prev_content=False | ||||
|         ) | ||||
| 
 | ||||
|         defer.returnValue({ | ||||
|             group: [ | ||||
|                 state_event_map[v] for v in event_id_map.values() if v in state_event_map | ||||
|                 state_event_map[v] for v in event_id_map.itervalues() | ||||
|                 if v in state_event_map | ||||
|             ] | ||||
|             for group, event_id_map in group_to_ids.items() | ||||
|             for group, event_id_map in group_to_ids.iteritems() | ||||
|         }) | ||||
| 
 | ||||
|     def _have_persisted_state_group_txn(self, txn, state_group): | ||||
| @ -124,6 +137,16 @@ class StateStore(SQLBaseStore): | ||||
|                 continue | ||||
| 
 | ||||
|             if context.current_state_ids is None: | ||||
|                 # AFAIK, this can never happen | ||||
|                 logger.error( | ||||
|                     "Non-outlier event %s had current_state_ids==None", | ||||
|                     event.event_id) | ||||
|                 continue | ||||
| 
 | ||||
|             # if the event was rejected, just give it the same state as its | ||||
|             # predecessor. | ||||
|             if context.rejected: | ||||
|                 state_groups[event.event_id] = context.prev_group | ||||
|                 continue | ||||
| 
 | ||||
|             state_groups[event.event_id] = context.state_group | ||||
| @ -168,7 +191,7 @@ class StateStore(SQLBaseStore): | ||||
|                             "state_key": key[1], | ||||
|                             "event_id": state_id, | ||||
|                         } | ||||
|                         for key, state_id in context.delta_ids.items() | ||||
|                         for key, state_id in context.delta_ids.iteritems() | ||||
|                     ], | ||||
|                 ) | ||||
|             else: | ||||
| @ -183,7 +206,7 @@ class StateStore(SQLBaseStore): | ||||
|                             "state_key": key[1], | ||||
|                             "event_id": state_id, | ||||
|                         } | ||||
|                         for key, state_id in context.current_state_ids.items() | ||||
|                         for key, state_id in context.current_state_ids.iteritems() | ||||
|                     ], | ||||
|                 ) | ||||
| 
 | ||||
| @ -195,7 +218,7 @@ class StateStore(SQLBaseStore): | ||||
|                     "state_group": state_group_id, | ||||
|                     "event_id": event_id, | ||||
|                 } | ||||
|                 for event_id, state_group_id in state_groups.items() | ||||
|                 for event_id, state_group_id in state_groups.iteritems() | ||||
|             ], | ||||
|         ) | ||||
| 
 | ||||
| @ -319,10 +342,10 @@ class StateStore(SQLBaseStore): | ||||
|                     args.extend(where_args) | ||||
| 
 | ||||
|                     txn.execute(sql % (where_clause,), args) | ||||
|                     rows = self.cursor_to_dict(txn) | ||||
|                     for row in rows: | ||||
|                         key = (row["type"], row["state_key"]) | ||||
|                         results[group][key] = row["event_id"] | ||||
|                     for row in txn: | ||||
|                         typ, state_key, event_id = row | ||||
|                         key = (typ, state_key) | ||||
|                         results[group][key] = event_id | ||||
|         else: | ||||
|             if types is not None: | ||||
|                 where_clause = "AND (%s)" % ( | ||||
| @ -351,12 +374,11 @@ class StateStore(SQLBaseStore): | ||||
|                         " WHERE state_group = ? %s" % (where_clause,), | ||||
|                         args | ||||
|                     ) | ||||
|                     rows = txn.fetchall() | ||||
|                     results[group].update({ | ||||
|                         (typ, state_key): event_id | ||||
|                         for typ, state_key, event_id in rows | ||||
|                     results[group].update( | ||||
|                         ((typ, state_key), event_id) | ||||
|                         for typ, state_key, event_id in txn | ||||
|                         if (typ, state_key) not in results[group] | ||||
|                     }) | ||||
|                     ) | ||||
| 
 | ||||
|                     # If the lengths match then we must have all the types, | ||||
|                     # so no need to go walk further down the tree. | ||||
| @ -393,21 +415,21 @@ class StateStore(SQLBaseStore): | ||||
|             event_ids, | ||||
|         ) | ||||
| 
 | ||||
|         groups = set(event_to_groups.values()) | ||||
|         groups = set(event_to_groups.itervalues()) | ||||
|         group_to_state = yield self._get_state_for_groups(groups, types) | ||||
| 
 | ||||
|         state_event_map = yield self.get_events( | ||||
|             [ev_id for sd in group_to_state.values() for ev_id in sd.values()], | ||||
|             [ev_id for sd in group_to_state.itervalues() for ev_id in sd.itervalues()], | ||||
|             get_prev_content=False | ||||
|         ) | ||||
| 
 | ||||
|         event_to_state = { | ||||
|             event_id: { | ||||
|                 k: state_event_map[v] | ||||
|                 for k, v in group_to_state[group].items() | ||||
|                 for k, v in group_to_state[group].iteritems() | ||||
|                 if v in state_event_map | ||||
|             } | ||||
|             for event_id, group in event_to_groups.items() | ||||
|             for event_id, group in event_to_groups.iteritems() | ||||
|         } | ||||
| 
 | ||||
|         defer.returnValue({event: event_to_state[event] for event in event_ids}) | ||||
| @ -430,12 +452,12 @@ class StateStore(SQLBaseStore): | ||||
|             event_ids, | ||||
|         ) | ||||
| 
 | ||||
|         groups = set(event_to_groups.values()) | ||||
|         groups = set(event_to_groups.itervalues()) | ||||
|         group_to_state = yield self._get_state_for_groups(groups, types) | ||||
| 
 | ||||
|         event_to_state = { | ||||
|             event_id: group_to_state[group] | ||||
|             for event_id, group in event_to_groups.items() | ||||
|             for event_id, group in event_to_groups.iteritems() | ||||
|         } | ||||
| 
 | ||||
|         defer.returnValue({event: event_to_state[event] for event in event_ids}) | ||||
| @ -474,7 +496,7 @@ class StateStore(SQLBaseStore): | ||||
|         state_map = yield self.get_state_ids_for_events([event_id], types) | ||||
|         defer.returnValue(state_map[event_id]) | ||||
| 
 | ||||
|     @cached(num_args=2, max_entries=10000) | ||||
|     @cached(num_args=2, max_entries=100000) | ||||
|     def _get_state_group_for_event(self, room_id, event_id): | ||||
|         return self._simple_select_one_onecol( | ||||
|             table="event_to_state_groups", | ||||
| @ -547,7 +569,7 @@ class StateStore(SQLBaseStore): | ||||
|         got_all = not (missing_types or types is None) | ||||
| 
 | ||||
|         return { | ||||
|             k: v for k, v in state_dict_ids.items() | ||||
|             k: v for k, v in state_dict_ids.iteritems() | ||||
|             if include(k[0], k[1]) | ||||
|         }, missing_types, got_all | ||||
| 
 | ||||
| @ -606,7 +628,7 @@ class StateStore(SQLBaseStore): | ||||
| 
 | ||||
|             # Now we want to update the cache with all the things we fetched | ||||
|             # from the database. | ||||
|             for group, group_state_dict in group_to_state_dict.items(): | ||||
|             for group, group_state_dict in group_to_state_dict.iteritems(): | ||||
|                 if types: | ||||
|                     # We delibrately put key -> None mappings into the cache to | ||||
|                     # cache absence of the key, on the assumption that if we've | ||||
| @ -621,10 +643,10 @@ class StateStore(SQLBaseStore): | ||||
|                 else: | ||||
|                     state_dict = results[group] | ||||
| 
 | ||||
|                 state_dict.update({ | ||||
|                     (intern_string(k[0]), intern_string(k[1])): v | ||||
|                     for k, v in group_state_dict.items() | ||||
|                 }) | ||||
|                 state_dict.update( | ||||
|                     ((intern_string(k[0]), intern_string(k[1])), v) | ||||
|                     for k, v in group_state_dict.iteritems() | ||||
|                 ) | ||||
| 
 | ||||
|                 self._state_group_cache.update( | ||||
|                     cache_seq_num, | ||||
| @ -635,10 +657,10 @@ class StateStore(SQLBaseStore): | ||||
| 
 | ||||
|         # Remove all the entries with None values. The None values were just | ||||
|         # used for bookkeeping in the cache. | ||||
|         for group, state_dict in results.items(): | ||||
|         for group, state_dict in results.iteritems(): | ||||
|             results[group] = { | ||||
|                 key: event_id | ||||
|                 for key, event_id in state_dict.items() | ||||
|                 for key, event_id in state_dict.iteritems() | ||||
|                 if event_id | ||||
|             } | ||||
| 
 | ||||
| @ -727,7 +749,7 @@ class StateStore(SQLBaseStore): | ||||
|                         # of keys | ||||
| 
 | ||||
|                         delta_state = { | ||||
|                             key: value for key, value in curr_state.items() | ||||
|                             key: value for key, value in curr_state.iteritems() | ||||
|                             if prev_state.get(key, None) != value | ||||
|                         } | ||||
| 
 | ||||
| @ -767,7 +789,7 @@ class StateStore(SQLBaseStore): | ||||
|                                     "state_key": key[1], | ||||
|                                     "event_id": state_id, | ||||
|                                 } | ||||
|                                 for key, state_id in delta_state.items() | ||||
|                                 for key, state_id in delta_state.iteritems() | ||||
|                             ], | ||||
|                         ) | ||||
| 
 | ||||
|  | ||||
| @ -829,3 +829,6 @@ class StreamStore(SQLBaseStore): | ||||
|             updatevalues={"stream_id": stream_id}, | ||||
|             desc="update_federation_out_pos", | ||||
|         ) | ||||
| 
 | ||||
|     def has_room_changed_since(self, room_id, stream_id): | ||||
|         return self._events_stream_cache.has_entity_changed(room_id, stream_id) | ||||
|  | ||||
| @ -95,7 +95,7 @@ class TagsStore(SQLBaseStore): | ||||
|             for stream_id, user_id, room_id in tag_ids: | ||||
|                 txn.execute(sql, (user_id, room_id)) | ||||
|                 tags = [] | ||||
|                 for tag, content in txn.fetchall(): | ||||
|                 for tag, content in txn: | ||||
|                     tags.append(json.dumps(tag) + ":" + content) | ||||
|                 tag_json = "{" + ",".join(tags) + "}" | ||||
|                 results.append((stream_id, user_id, room_id, tag_json)) | ||||
| @ -132,7 +132,7 @@ class TagsStore(SQLBaseStore): | ||||
|                 " WHERE user_id = ? AND stream_id > ?" | ||||
|             ) | ||||
|             txn.execute(sql, (user_id, stream_id)) | ||||
|             room_ids = [row[0] for row in txn.fetchall()] | ||||
|             room_ids = [row[0] for row in txn] | ||||
|             return room_ids | ||||
| 
 | ||||
|         changed = self._account_data_stream_cache.has_entity_changed( | ||||
|  | ||||
| @ -30,6 +30,17 @@ class IdGenerator(object): | ||||
| 
 | ||||
| 
 | ||||
| def _load_current_id(db_conn, table, column, step=1): | ||||
|     """ | ||||
| 
 | ||||
|     Args: | ||||
|         db_conn (object): | ||||
|         table (str): | ||||
|         column (str): | ||||
|         step (int): | ||||
| 
 | ||||
|     Returns: | ||||
|         int | ||||
|     """ | ||||
|     cur = db_conn.cursor() | ||||
|     if step == 1: | ||||
|         cur.execute("SELECT MAX(%s) FROM %s" % (column, table,)) | ||||
| @ -131,6 +142,9 @@ class StreamIdGenerator(object): | ||||
|     def get_current_token(self): | ||||
|         """Returns the maximum stream id such that all stream ids less than or | ||||
|         equal to it have been successfully persisted. | ||||
| 
 | ||||
|         Returns: | ||||
|             int | ||||
|         """ | ||||
|         with self._lock: | ||||
|             if self._unfinished_ids: | ||||
|  | ||||
| @ -26,7 +26,7 @@ logger = logging.getLogger(__name__) | ||||
| 
 | ||||
| class DeferredTimedOutError(SynapseError): | ||||
|     def __init__(self): | ||||
|         super(SynapseError).__init__(504, "Timed out") | ||||
|         super(SynapseError, self).__init__(504, "Timed out") | ||||
| 
 | ||||
| 
 | ||||
| def unwrapFirstError(failure): | ||||
| @ -93,8 +93,10 @@ class Clock(object): | ||||
|         ret_deferred = defer.Deferred() | ||||
| 
 | ||||
|         def timed_out_fn(): | ||||
|             e = DeferredTimedOutError() | ||||
| 
 | ||||
|             try: | ||||
|                 ret_deferred.errback(DeferredTimedOutError()) | ||||
|                 ret_deferred.errback(e) | ||||
|             except: | ||||
|                 pass | ||||
| 
 | ||||
| @ -114,7 +116,7 @@ class Clock(object): | ||||
| 
 | ||||
|         ret_deferred.addBoth(cancel) | ||||
| 
 | ||||
|         def sucess(res): | ||||
|         def success(res): | ||||
|             try: | ||||
|                 ret_deferred.callback(res) | ||||
|             except: | ||||
| @ -128,7 +130,7 @@ class Clock(object): | ||||
|             except: | ||||
|                 pass | ||||
| 
 | ||||
|         given_deferred.addCallbacks(callback=sucess, errback=err) | ||||
|         given_deferred.addCallbacks(callback=success, errback=err) | ||||
| 
 | ||||
|         timer = self.call_later(time_out, timed_out_fn) | ||||
| 
 | ||||
|  | ||||
| @ -15,12 +15,9 @@ | ||||
| import logging | ||||
| 
 | ||||
| from synapse.util.async import ObservableDeferred | ||||
| from synapse.util import unwrapFirstError | ||||
| from synapse.util import unwrapFirstError, logcontext | ||||
| from synapse.util.caches.lrucache import LruCache | ||||
| from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry | ||||
| from synapse.util.logcontext import ( | ||||
|     PreserveLoggingContext, preserve_context_over_deferred, preserve_context_over_fn | ||||
| ) | ||||
| 
 | ||||
| from . import DEBUG_CACHES, register_cache | ||||
| 
 | ||||
| @ -189,7 +186,55 @@ class Cache(object): | ||||
|         self.cache.clear() | ||||
| 
 | ||||
| 
 | ||||
| class CacheDescriptor(object): | ||||
| class _CacheDescriptorBase(object): | ||||
|     def __init__(self, orig, num_args, inlineCallbacks, cache_context=False): | ||||
|         self.orig = orig | ||||
| 
 | ||||
|         if inlineCallbacks: | ||||
|             self.function_to_call = defer.inlineCallbacks(orig) | ||||
|         else: | ||||
|             self.function_to_call = orig | ||||
| 
 | ||||
|         arg_spec = inspect.getargspec(orig) | ||||
|         all_args = arg_spec.args | ||||
| 
 | ||||
|         if "cache_context" in all_args: | ||||
|             if not cache_context: | ||||
|                 raise ValueError( | ||||
|                     "Cannot have a 'cache_context' arg without setting" | ||||
|                     " cache_context=True" | ||||
|                 ) | ||||
|         elif cache_context: | ||||
|             raise ValueError( | ||||
|                 "Cannot have cache_context=True without having an arg" | ||||
|                 " named `cache_context`" | ||||
|             ) | ||||
| 
 | ||||
|         if num_args is None: | ||||
|             num_args = len(all_args) - 1 | ||||
|             if cache_context: | ||||
|                 num_args -= 1 | ||||
| 
 | ||||
|         if len(all_args) < num_args + 1: | ||||
|             raise Exception( | ||||
|                 "Not enough explicit positional arguments to key off for %r: " | ||||
|                 "got %i args, but wanted %i. (@cached cannot key off *args or " | ||||
|                 "**kwargs)" | ||||
|                 % (orig.__name__, len(all_args), num_args) | ||||
|             ) | ||||
| 
 | ||||
|         self.num_args = num_args | ||||
|         self.arg_names = all_args[1:num_args + 1] | ||||
| 
 | ||||
|         if "cache_context" in self.arg_names: | ||||
|             raise Exception( | ||||
|                 "cache_context arg cannot be included among the cache keys" | ||||
|             ) | ||||
| 
 | ||||
|         self.add_cache_context = cache_context | ||||
| 
 | ||||
| 
 | ||||
| class CacheDescriptor(_CacheDescriptorBase): | ||||
|     """ A method decorator that applies a memoizing cache around the function. | ||||
| 
 | ||||
|     This caches deferreds, rather than the results themselves. Deferreds that | ||||
| @ -217,52 +262,24 @@ class CacheDescriptor(object): | ||||
|             r2 = yield self.bar2(key, on_invalidate=cache_context.invalidate) | ||||
|             defer.returnValue(r1 + r2) | ||||
| 
 | ||||
|     Args: | ||||
|         num_args (int): number of positional arguments (excluding ``self`` and | ||||
|             ``cache_context``) to use as cache keys. Defaults to all named | ||||
|             args of the function. | ||||
|     """ | ||||
|     def __init__(self, orig, max_entries=1000, num_args=1, tree=False, | ||||
|     def __init__(self, orig, max_entries=1000, num_args=None, tree=False, | ||||
|                  inlineCallbacks=False, cache_context=False, iterable=False): | ||||
| 
 | ||||
|         super(CacheDescriptor, self).__init__( | ||||
|             orig, num_args=num_args, inlineCallbacks=inlineCallbacks, | ||||
|             cache_context=cache_context) | ||||
| 
 | ||||
|         max_entries = int(max_entries * CACHE_SIZE_FACTOR) | ||||
| 
 | ||||
|         self.orig = orig | ||||
| 
 | ||||
|         if inlineCallbacks: | ||||
|             self.function_to_call = defer.inlineCallbacks(orig) | ||||
|         else: | ||||
|             self.function_to_call = orig | ||||
| 
 | ||||
|         self.max_entries = max_entries | ||||
|         self.num_args = num_args | ||||
|         self.tree = tree | ||||
| 
 | ||||
|         self.iterable = iterable | ||||
| 
 | ||||
|         all_args = inspect.getargspec(orig) | ||||
|         self.arg_names = all_args.args[1:num_args + 1] | ||||
| 
 | ||||
|         if "cache_context" in all_args.args: | ||||
|             if not cache_context: | ||||
|                 raise ValueError( | ||||
|                     "Cannot have a 'cache_context' arg without setting" | ||||
|                     " cache_context=True" | ||||
|                 ) | ||||
|             try: | ||||
|                 self.arg_names.remove("cache_context") | ||||
|             except ValueError: | ||||
|                 pass | ||||
|         elif cache_context: | ||||
|             raise ValueError( | ||||
|                 "Cannot have cache_context=True without having an arg" | ||||
|                 " named `cache_context`" | ||||
|             ) | ||||
| 
 | ||||
|         self.add_cache_context = cache_context | ||||
| 
 | ||||
|         if len(self.arg_names) < self.num_args: | ||||
|             raise Exception( | ||||
|                 "Not enough explicit positional arguments to key off of for %r." | ||||
|                 " (@cached cannot key off of *args or **kwargs)" | ||||
|                 % (orig.__name__,) | ||||
|             ) | ||||
| 
 | ||||
|     def __get__(self, obj, objtype=None): | ||||
|         cache = Cache( | ||||
|             name=self.orig.__name__, | ||||
| @ -308,11 +325,9 @@ class CacheDescriptor(object): | ||||
|                         defer.returnValue(cached_result) | ||||
|                     observer.addCallback(check_result) | ||||
| 
 | ||||
|                 return preserve_context_over_deferred(observer) | ||||
|             except KeyError: | ||||
|                 ret = defer.maybeDeferred( | ||||
|                     preserve_context_over_fn, | ||||
|                     self.function_to_call, | ||||
|                     logcontext.preserve_fn(self.function_to_call), | ||||
|                     obj, *args, **kwargs | ||||
|                 ) | ||||
| 
 | ||||
| @ -322,10 +337,11 @@ class CacheDescriptor(object): | ||||
| 
 | ||||
|                 ret.addErrback(onErr) | ||||
| 
 | ||||
|                 ret = ObservableDeferred(ret, consumeErrors=True) | ||||
|                 cache.set(cache_key, ret, callback=invalidate_callback) | ||||
|                 result_d = ObservableDeferred(ret, consumeErrors=True) | ||||
|                 cache.set(cache_key, result_d, callback=invalidate_callback) | ||||
|                 observer = result_d.observe() | ||||
| 
 | ||||
|                 return preserve_context_over_deferred(ret.observe()) | ||||
|             return logcontext.make_deferred_yieldable(observer) | ||||
| 
 | ||||
|         wrapped.invalidate = cache.invalidate | ||||
|         wrapped.invalidate_all = cache.invalidate_all | ||||
| @ -338,48 +354,40 @@ class CacheDescriptor(object): | ||||
|         return wrapped | ||||
| 
 | ||||
| 
 | ||||
| class CacheListDescriptor(object): | ||||
| class CacheListDescriptor(_CacheDescriptorBase): | ||||
|     """Wraps an existing cache to support bulk fetching of keys. | ||||
| 
 | ||||
|     Given a list of keys it looks in the cache to find any hits, then passes | ||||
|     the list of missing keys to the wrapped fucntion. | ||||
|     the list of missing keys to the wrapped function. | ||||
| 
 | ||||
|     Once wrapped, the function returns either a Deferred which resolves to | ||||
|     the list of results, or (if all results were cached), just the list of | ||||
|     results. | ||||
|     """ | ||||
| 
 | ||||
|     def __init__(self, orig, cached_method_name, list_name, num_args=1, | ||||
|     def __init__(self, orig, cached_method_name, list_name, num_args=None, | ||||
|                  inlineCallbacks=False): | ||||
|         """ | ||||
|         Args: | ||||
|             orig (function) | ||||
|             method_name (str); The name of the chached method. | ||||
|             cached_method_name (str): The name of the chached method. | ||||
|             list_name (str): Name of the argument which is the bulk lookup list | ||||
|             num_args (int) | ||||
|             num_args (int): number of positional arguments (excluding ``self``, | ||||
|                 but including list_name) to use as cache keys. Defaults to all | ||||
|                 named args of the function. | ||||
|             inlineCallbacks (bool): Whether orig is a generator that should | ||||
|                 be wrapped by defer.inlineCallbacks | ||||
|         """ | ||||
|         self.orig = orig | ||||
|         super(CacheListDescriptor, self).__init__( | ||||
|             orig, num_args=num_args, inlineCallbacks=inlineCallbacks) | ||||
| 
 | ||||
|         if inlineCallbacks: | ||||
|             self.function_to_call = defer.inlineCallbacks(orig) | ||||
|         else: | ||||
|             self.function_to_call = orig | ||||
| 
 | ||||
|         self.num_args = num_args | ||||
|         self.list_name = list_name | ||||
| 
 | ||||
|         self.arg_names = inspect.getargspec(orig).args[1:num_args + 1] | ||||
|         self.list_pos = self.arg_names.index(self.list_name) | ||||
| 
 | ||||
|         self.cached_method_name = cached_method_name | ||||
| 
 | ||||
|         self.sentinel = object() | ||||
| 
 | ||||
|         if len(self.arg_names) < self.num_args: | ||||
|             raise Exception( | ||||
|                 "Not enough explicit positional arguments to key off of for %r." | ||||
|                 " (@cached cannot key off of *args or **kwars)" | ||||
|                 % (orig.__name__,) | ||||
|             ) | ||||
| 
 | ||||
|         if self.list_name not in self.arg_names: | ||||
|             raise Exception( | ||||
|                 "Couldn't see arguments %r for %r." | ||||
| @ -425,8 +433,7 @@ class CacheListDescriptor(object): | ||||
|                 args_to_call[self.list_name] = missing | ||||
| 
 | ||||
|                 ret_d = defer.maybeDeferred( | ||||
|                     preserve_context_over_fn, | ||||
|                     self.function_to_call, | ||||
|                     logcontext.preserve_fn(self.function_to_call), | ||||
|                     **args_to_call | ||||
|                 ) | ||||
| 
 | ||||
| @ -435,8 +442,7 @@ class CacheListDescriptor(object): | ||||
|                 # We need to create deferreds for each arg in the list so that | ||||
|                 # we can insert the new deferred into the cache. | ||||
|                 for arg in missing: | ||||
|                     with PreserveLoggingContext(): | ||||
|                         observer = ret_d.observe() | ||||
|                     observer = ret_d.observe() | ||||
|                     observer.addCallback(lambda r, arg: r.get(arg, None), arg) | ||||
| 
 | ||||
|                     observer = ObservableDeferred(observer) | ||||
| @ -463,7 +469,7 @@ class CacheListDescriptor(object): | ||||
|                     results.update(res) | ||||
|                     return results | ||||
| 
 | ||||
|                 return preserve_context_over_deferred(defer.gatherResults( | ||||
|                 return logcontext.make_deferred_yieldable(defer.gatherResults( | ||||
|                     cached_defers.values(), | ||||
|                     consumeErrors=True, | ||||
|                 ).addCallback(update_results_dict).addErrback( | ||||
| @ -487,7 +493,7 @@ class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))): | ||||
|         self.cache.invalidate(self.key) | ||||
| 
 | ||||
| 
 | ||||
| def cached(max_entries=1000, num_args=1, tree=False, cache_context=False, | ||||
| def cached(max_entries=1000, num_args=None, tree=False, cache_context=False, | ||||
|            iterable=False): | ||||
|     return lambda orig: CacheDescriptor( | ||||
|         orig, | ||||
| @ -499,8 +505,8 @@ def cached(max_entries=1000, num_args=1, tree=False, cache_context=False, | ||||
|     ) | ||||
| 
 | ||||
| 
 | ||||
| def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_context=False, | ||||
|                           iterable=False): | ||||
| def cachedInlineCallbacks(max_entries=1000, num_args=None, tree=False, | ||||
|                           cache_context=False, iterable=False): | ||||
|     return lambda orig: CacheDescriptor( | ||||
|         orig, | ||||
|         max_entries=max_entries, | ||||
| @ -512,7 +518,7 @@ def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_contex | ||||
|     ) | ||||
| 
 | ||||
| 
 | ||||
| def cachedList(cached_method_name, list_name, num_args=1, inlineCallbacks=False): | ||||
| def cachedList(cached_method_name, list_name, num_args=None, inlineCallbacks=False): | ||||
|     """Creates a descriptor that wraps a function in a `CacheListDescriptor`. | ||||
| 
 | ||||
|     Used to do batch lookups for an already created cache. A single argument | ||||
| @ -525,7 +531,8 @@ def cachedList(cached_method_name, list_name, num_args=1, inlineCallbacks=False) | ||||
|         cache (Cache): The underlying cache to use. | ||||
|         list_name (str): The name of the argument that is the list to use to | ||||
|             do batch lookups in the cache. | ||||
|         num_args (int): Number of arguments to use as the key in the cache. | ||||
|         num_args (int): Number of arguments to use as the key in the cache | ||||
|             (including list_name). Defaults to all named parameters. | ||||
|         inlineCallbacks (bool): Should the function be wrapped in an | ||||
|             `defer.inlineCallbacks`? | ||||
| 
 | ||||
|  | ||||
| @ -50,7 +50,7 @@ class StreamChangeCache(object): | ||||
|     def has_entity_changed(self, entity, stream_pos): | ||||
|         """Returns True if the entity may have been updated since stream_pos | ||||
|         """ | ||||
|         assert type(stream_pos) is int | ||||
|         assert type(stream_pos) is int or type(stream_pos) is long | ||||
| 
 | ||||
|         if stream_pos < self._earliest_known_stream_pos: | ||||
|             self.metrics.inc_misses() | ||||
|  | ||||
| @ -12,6 +12,16 @@ | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| 
 | ||||
| """ Thread-local-alike tracking of log contexts within synapse | ||||
| 
 | ||||
| This module provides objects and utilities for tracking contexts through | ||||
| synapse code, so that log lines can include a request identifier, and so that | ||||
| CPU and database activity can be accounted for against the request that caused | ||||
| them. | ||||
| 
 | ||||
| See doc/log_contexts.rst for details on how this works. | ||||
| """ | ||||
| 
 | ||||
| from twisted.internet import defer | ||||
| 
 | ||||
| import threading | ||||
| @ -300,6 +310,10 @@ def preserve_context_over_fn(fn, *args, **kwargs): | ||||
| def preserve_context_over_deferred(deferred, context=None): | ||||
|     """Given a deferred wrap it such that any callbacks added later to it will | ||||
|     be invoked with the current context. | ||||
| 
 | ||||
|     Deprecated: this almost certainly doesn't do want you want, ie make | ||||
|     the deferred follow the synapse logcontext rules: try | ||||
|     ``make_deferred_yieldable`` instead. | ||||
|     """ | ||||
|     if context is None: | ||||
|         context = LoggingContext.current_context() | ||||
| @ -309,24 +323,65 @@ def preserve_context_over_deferred(deferred, context=None): | ||||
| 
 | ||||
| 
 | ||||
| def preserve_fn(f): | ||||
|     """Ensures that function is called with correct context and that context is | ||||
|     restored after return. Useful for wrapping functions that return a deferred | ||||
|     which you don't yield on. | ||||
|     """Wraps a function, to ensure that the current context is restored after | ||||
|     return from the function, and that the sentinel context is set once the | ||||
|     deferred returned by the funtion completes. | ||||
| 
 | ||||
|     Useful for wrapping functions that return a deferred which you don't yield | ||||
|     on. | ||||
|     """ | ||||
|     def reset_context(result): | ||||
|         LoggingContext.set_current_context(LoggingContext.sentinel) | ||||
|         return result | ||||
| 
 | ||||
|     # XXX: why is this here rather than inside g? surely we want to preserve | ||||
|     # the context from the time the function was called, not when it was | ||||
|     # wrapped? | ||||
|     current = LoggingContext.current_context() | ||||
| 
 | ||||
|     def g(*args, **kwargs): | ||||
|         with PreserveLoggingContext(current): | ||||
|             res = f(*args, **kwargs) | ||||
|             if isinstance(res, defer.Deferred): | ||||
|                 return preserve_context_over_deferred( | ||||
|                     res, context=LoggingContext.sentinel | ||||
|                 ) | ||||
|             else: | ||||
|                 return res | ||||
|         res = f(*args, **kwargs) | ||||
|         if isinstance(res, defer.Deferred) and not res.called: | ||||
|             # The function will have reset the context before returning, so | ||||
|             # we need to restore it now. | ||||
|             LoggingContext.set_current_context(current) | ||||
| 
 | ||||
|             # The original context will be restored when the deferred | ||||
|             # completes, but there is nothing waiting for it, so it will | ||||
|             # get leaked into the reactor or some other function which | ||||
|             # wasn't expecting it. We therefore need to reset the context | ||||
|             # here. | ||||
|             # | ||||
|             # (If this feels asymmetric, consider it this way: we are | ||||
|             # effectively forking a new thread of execution. We are | ||||
|             # probably currently within a ``with LoggingContext()`` block, | ||||
|             # which is supposed to have a single entry and exit point. But | ||||
|             # by spawning off another deferred, we are effectively | ||||
|             # adding a new exit point.) | ||||
|             res.addBoth(reset_context) | ||||
|         return res | ||||
|     return g | ||||
| 
 | ||||
| 
 | ||||
| @defer.inlineCallbacks | ||||
| def make_deferred_yieldable(deferred): | ||||
|     """Given a deferred, make it follow the Synapse logcontext rules: | ||||
| 
 | ||||
|     If the deferred has completed (or is not actually a Deferred), essentially | ||||
|     does nothing (just returns another completed deferred with the | ||||
|     result/failure). | ||||
| 
 | ||||
|     If the deferred has not yet completed, resets the logcontext before | ||||
|     returning a deferred. Then, when the deferred completes, restores the | ||||
|     current logcontext before running callbacks/errbacks. | ||||
| 
 | ||||
|     (This is more-or-less the opposite operation to preserve_fn.) | ||||
|     """ | ||||
|     with PreserveLoggingContext(): | ||||
|         r = yield deferred | ||||
|     defer.returnValue(r) | ||||
| 
 | ||||
| 
 | ||||
| # modules to ignore in `logcontext_tracer` | ||||
| _to_ignore = [ | ||||
|     "synapse.util.logcontext", | ||||
|  | ||||
							
								
								
									
										40
									
								
								synapse/util/msisdn.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										40
									
								
								synapse/util/msisdn.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,40 @@ | ||||
| # -*- coding: utf-8 -*- | ||||
| # Copyright 2017 Vector Creations Ltd | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| 
 | ||||
| import phonenumbers | ||||
| from synapse.api.errors import SynapseError | ||||
| 
 | ||||
| 
 | ||||
| def phone_number_to_msisdn(country, number): | ||||
|     """ | ||||
|     Takes an ISO-3166-1 2 letter country code and phone number and | ||||
|     returns an msisdn representing the canonical version of that | ||||
|     phone number. | ||||
|     Args: | ||||
|         country (str): ISO-3166-1 2 letter country code | ||||
|         number (str): Phone number in a national or international format | ||||
| 
 | ||||
|     Returns: | ||||
|         (str) The canonical form of the phone number, as an msisdn | ||||
|     Raises: | ||||
|             SynapseError if the number could not be parsed. | ||||
|     """ | ||||
|     try: | ||||
|         phoneNumber = phonenumbers.parse(number, country) | ||||
|     except phonenumbers.NumberParseException: | ||||
|         raise SynapseError(400, "Unable to parse phone number") | ||||
|     return phonenumbers.format_number( | ||||
|         phoneNumber, phonenumbers.PhoneNumberFormat.E164 | ||||
|     )[1:] | ||||
| @ -12,7 +12,7 @@ | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| 
 | ||||
| import synapse.util.logcontext | ||||
| from twisted.internet import defer | ||||
| 
 | ||||
| from synapse.api.errors import CodeMessageException | ||||
| @ -35,7 +35,8 @@ class NotRetryingDestination(Exception): | ||||
| 
 | ||||
| 
 | ||||
| @defer.inlineCallbacks | ||||
| def get_retry_limiter(destination, clock, store, **kwargs): | ||||
| def get_retry_limiter(destination, clock, store, ignore_backoff=False, | ||||
|                       **kwargs): | ||||
|     """For a given destination check if we have previously failed to | ||||
|     send a request there and are waiting before retrying the destination. | ||||
|     If we are not ready to retry the destination, this will raise a | ||||
| @ -43,6 +44,14 @@ def get_retry_limiter(destination, clock, store, **kwargs): | ||||
|     that will mark the destination as down if an exception is thrown (excluding | ||||
|     CodeMessageException with code < 500) | ||||
| 
 | ||||
|     Args: | ||||
|         destination (str): name of homeserver | ||||
|         clock (synapse.util.clock): timing source | ||||
|         store (synapse.storage.transactions.TransactionStore): datastore | ||||
|         ignore_backoff (bool): true to ignore the historical backoff data and | ||||
|             try the request anyway. We will still update the next | ||||
|             retry_interval on success/failure. | ||||
| 
 | ||||
|     Example usage: | ||||
| 
 | ||||
|         try: | ||||
| @ -66,7 +75,7 @@ def get_retry_limiter(destination, clock, store, **kwargs): | ||||
| 
 | ||||
|         now = int(clock.time_msec()) | ||||
| 
 | ||||
|         if retry_last_ts + retry_interval > now: | ||||
|         if not ignore_backoff and retry_last_ts + retry_interval > now: | ||||
|             raise NotRetryingDestination( | ||||
|                 retry_last_ts=retry_last_ts, | ||||
|                 retry_interval=retry_interval, | ||||
| @ -124,7 +133,13 @@ class RetryDestinationLimiter(object): | ||||
| 
 | ||||
|     def __exit__(self, exc_type, exc_val, exc_tb): | ||||
|         valid_err_code = False | ||||
|         if exc_type is not None and issubclass(exc_type, CodeMessageException): | ||||
|         if exc_type is None: | ||||
|             valid_err_code = True | ||||
|         elif not issubclass(exc_type, Exception): | ||||
|             # avoid treating exceptions which don't derive from Exception as | ||||
|             # failures; this is mostly so as not to catch defer._DefGen. | ||||
|             valid_err_code = True | ||||
|         elif issubclass(exc_type, CodeMessageException): | ||||
|             # Some error codes are perfectly fine for some APIs, whereas other | ||||
|             # APIs may expect to never received e.g. a 404. It's important to | ||||
|             # handle 404 as some remote servers will return a 404 when the HS | ||||
| @ -142,11 +157,13 @@ class RetryDestinationLimiter(object): | ||||
|             else: | ||||
|                 valid_err_code = False | ||||
| 
 | ||||
|         if exc_type is None or valid_err_code: | ||||
|         if valid_err_code: | ||||
|             # We connected successfully. | ||||
|             if not self.retry_interval: | ||||
|                 return | ||||
| 
 | ||||
|             logger.debug("Connection to %s was successful; clearing backoff", | ||||
|                          self.destination) | ||||
|             retry_last_ts = 0 | ||||
|             self.retry_interval = 0 | ||||
|         else: | ||||
| @ -160,6 +177,10 @@ class RetryDestinationLimiter(object): | ||||
|             else: | ||||
|                 self.retry_interval = self.min_retry_interval | ||||
| 
 | ||||
|             logger.debug( | ||||
|                 "Connection to %s was unsuccessful (%s(%s)); backoff now %i", | ||||
|                 self.destination, exc_type, exc_val, self.retry_interval | ||||
|             ) | ||||
|             retry_last_ts = int(self.clock.time_msec()) | ||||
| 
 | ||||
|         @defer.inlineCallbacks | ||||
| @ -173,4 +194,5 @@ class RetryDestinationLimiter(object): | ||||
|                     "Failed to store set_destination_retry_timings", | ||||
|                 ) | ||||
| 
 | ||||
|         store_retry_timings() | ||||
|         # we deliberately do this in the background. | ||||
|         synapse.util.logcontext.preserve_fn(store_retry_timings)() | ||||
|  | ||||
| @ -134,6 +134,13 @@ def filter_events_for_clients(store, user_tuples, events, event_id_to_state): | ||||
|             if prev_membership not in MEMBERSHIP_PRIORITY: | ||||
|                 prev_membership = "leave" | ||||
| 
 | ||||
|             # Always allow the user to see their own leave events, otherwise | ||||
|             # they won't see the room disappear if they reject the invite | ||||
|             if membership == "leave" and ( | ||||
|                 prev_membership == "join" or prev_membership == "invite" | ||||
|             ): | ||||
|                 return True | ||||
| 
 | ||||
|             new_priority = MEMBERSHIP_PRIORITY.index(membership) | ||||
|             old_priority = MEMBERSHIP_PRIORITY.index(prev_membership) | ||||
|             if old_priority < new_priority: | ||||
|  | ||||
| @ -23,6 +23,9 @@ from tests.utils import ( | ||||
| 
 | ||||
| from synapse.api.filtering import Filter | ||||
| from synapse.events import FrozenEvent | ||||
| from synapse.api.errors import SynapseError | ||||
| 
 | ||||
| import jsonschema | ||||
| 
 | ||||
| user_localpart = "test_user" | ||||
| 
 | ||||
| @ -54,6 +57,70 @@ class FilteringTestCase(unittest.TestCase): | ||||
| 
 | ||||
|         self.datastore = hs.get_datastore() | ||||
| 
 | ||||
|     def test_errors_on_invalid_filters(self): | ||||
|         invalid_filters = [ | ||||
|             {"boom": {}}, | ||||
|             {"account_data": "Hello World"}, | ||||
|             {"event_fields": ["\\foo"]}, | ||||
|             {"room": {"timeline": {"limit": 0}, "state": {"not_bars": ["*"]}}}, | ||||
|             {"event_format": "other"}, | ||||
|             {"room": {"not_rooms": ["#foo:pik-test"]}}, | ||||
|             {"presence": {"senders": ["@bar;pik.test.com"]}} | ||||
|         ] | ||||
|         for filter in invalid_filters: | ||||
|             with self.assertRaises(SynapseError) as check_filter_error: | ||||
|                 self.filtering.check_valid_filter(filter) | ||||
|                 self.assertIsInstance(check_filter_error.exception, SynapseError) | ||||
| 
 | ||||
|     def test_valid_filters(self): | ||||
|         valid_filters = [ | ||||
|             { | ||||
|                 "room": { | ||||
|                     "timeline": {"limit": 20}, | ||||
|                     "state": {"not_types": ["m.room.member"]}, | ||||
|                     "ephemeral": {"limit": 0, "not_types": ["*"]}, | ||||
|                     "include_leave": False, | ||||
|                     "rooms": ["!dee:pik-test"], | ||||
|                     "not_rooms": ["!gee:pik-test"], | ||||
|                     "account_data": {"limit": 0, "types": ["*"]} | ||||
|                 } | ||||
|             }, | ||||
|             { | ||||
|                 "room": { | ||||
|                     "state": { | ||||
|                         "types": ["m.room.*"], | ||||
|                         "not_rooms": ["!726s6s6q:example.com"] | ||||
|                     }, | ||||
|                     "timeline": { | ||||
|                         "limit": 10, | ||||
|                         "types": ["m.room.message"], | ||||
|                         "not_rooms": ["!726s6s6q:example.com"], | ||||
|                         "not_senders": ["@spam:example.com"] | ||||
|                     }, | ||||
|                     "ephemeral": { | ||||
|                         "types": ["m.receipt", "m.typing"], | ||||
|                         "not_rooms": ["!726s6s6q:example.com"], | ||||
|                         "not_senders": ["@spam:example.com"] | ||||
|                     } | ||||
|                 }, | ||||
|                 "presence": { | ||||
|                     "types": ["m.presence"], | ||||
|                     "not_senders": ["@alice:example.com"] | ||||
|                 }, | ||||
|                 "event_format": "client", | ||||
|                 "event_fields": ["type", "content", "sender"] | ||||
|             } | ||||
|         ] | ||||
|         for filter in valid_filters: | ||||
|             try: | ||||
|                 self.filtering.check_valid_filter(filter) | ||||
|             except jsonschema.ValidationError as e: | ||||
|                 self.fail(e) | ||||
| 
 | ||||
|     def test_limits_are_applied(self): | ||||
|         # TODO | ||||
|         pass | ||||
| 
 | ||||
|     def test_definition_types_works_with_literals(self): | ||||
|         definition = { | ||||
|             "types": ["m.room.message", "org.matrix.foo.bar"] | ||||
|  | ||||
| @ -93,6 +93,7 @@ class DirectoryTestCase(unittest.TestCase): | ||||
|                 "room_alias": "#another:remote", | ||||
|             }, | ||||
|             retry_on_dns_fail=False, | ||||
|             ignore_backoff=True, | ||||
|         ) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|  | ||||
| @ -324,7 +324,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): | ||||
|         state = UserPresenceState.default(user_id) | ||||
|         state = state.copy_and_replace( | ||||
|             state=PresenceState.ONLINE, | ||||
|             last_active_ts=now, | ||||
|             last_active_ts=0, | ||||
|             last_user_sync_ts=now - SYNC_ONLINE_TIMEOUT - 1, | ||||
|         ) | ||||
| 
 | ||||
|  | ||||
| @ -119,7 +119,8 @@ class ProfileTestCase(unittest.TestCase): | ||||
|         self.mock_federation.make_query.assert_called_with( | ||||
|             destination="remote", | ||||
|             query_type="profile", | ||||
|             args={"user_id": "@alice:remote", "field": "displayname"} | ||||
|             args={"user_id": "@alice:remote", "field": "displayname"}, | ||||
|             ignore_backoff=True, | ||||
|         ) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|  | ||||
| @ -192,6 +192,7 @@ class TypingNotificationsTestCase(unittest.TestCase): | ||||
|                 ), | ||||
|                 json_data_callback=ANY, | ||||
|                 long_retries=True, | ||||
|                 backoff_on_404=True, | ||||
|             ), | ||||
|             defer.succeed((200, "OK")) | ||||
|         ) | ||||
| @ -263,6 +264,7 @@ class TypingNotificationsTestCase(unittest.TestCase): | ||||
|                 ), | ||||
|                 json_data_callback=ANY, | ||||
|                 long_retries=True, | ||||
|                 backoff_on_404=True, | ||||
|             ), | ||||
|             defer.succeed((200, "OK")) | ||||
|         ) | ||||
|  | ||||
| @ -68,7 +68,7 @@ class ReplicationResourceCase(unittest.TestCase): | ||||
|         code, body = yield get | ||||
|         self.assertEquals(code, 200) | ||||
|         self.assertEquals(body["events"]["field_names"], [ | ||||
|             "position", "internal", "json", "state_group" | ||||
|             "position", "event_id", "room_id", "type", "state_key", | ||||
|         ]) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|  | ||||
| @ -33,8 +33,8 @@ PATH_PREFIX = "/_matrix/client/v2_alpha" | ||||
| class FilterTestCase(unittest.TestCase): | ||||
| 
 | ||||
|     USER_ID = "@apple:test" | ||||
|     EXAMPLE_FILTER = {"type": ["m.*"]} | ||||
|     EXAMPLE_FILTER_JSON = '{"type": ["m.*"]}' | ||||
|     EXAMPLE_FILTER = {"room": {"timeline": {"types": ["m.room.message"]}}} | ||||
|     EXAMPLE_FILTER_JSON = '{"room": {"timeline": {"types": ["m.room.message"]}}}' | ||||
|     TO_REGISTER = [filter] | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|  | ||||
| @ -89,7 +89,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): | ||||
|     @defer.inlineCallbacks | ||||
|     def test_select_one_1col(self): | ||||
|         self.mock_txn.rowcount = 1 | ||||
|         self.mock_txn.fetchall.return_value = [("Value",)] | ||||
|         self.mock_txn.__iter__ = Mock(return_value=iter([("Value",)])) | ||||
| 
 | ||||
|         value = yield self.datastore._simple_select_one_onecol( | ||||
|             table="tablename", | ||||
| @ -136,7 +136,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): | ||||
|     @defer.inlineCallbacks | ||||
|     def test_select_list(self): | ||||
|         self.mock_txn.rowcount = 3 | ||||
|         self.mock_txn.fetchall.return_value = ((1,), (2,), (3,)) | ||||
|         self.mock_txn.__iter__ = Mock(return_value=iter([(1,), (2,), (3,)])) | ||||
|         self.mock_txn.description = ( | ||||
|             ("colA", None, None, None, None, None, None), | ||||
|         ) | ||||
|  | ||||
							
								
								
									
										53
									
								
								tests/storage/test_keys.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										53
									
								
								tests/storage/test_keys.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,53 @@ | ||||
| # -*- coding: utf-8 -*- | ||||
| # Copyright 2017 Vector Creations Ltd | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| 
 | ||||
| import signedjson.key | ||||
| from twisted.internet import defer | ||||
| 
 | ||||
| import tests.unittest | ||||
| import tests.utils | ||||
| 
 | ||||
| 
 | ||||
| class KeyStoreTestCase(tests.unittest.TestCase): | ||||
|     def __init__(self, *args, **kwargs): | ||||
|         super(KeyStoreTestCase, self).__init__(*args, **kwargs) | ||||
|         self.store = None  # type: synapse.storage.keys.KeyStore | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def setUp(self): | ||||
|         hs = yield tests.utils.setup_test_homeserver() | ||||
|         self.store = hs.get_datastore() | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def test_get_server_verify_keys(self): | ||||
|         key1 = signedjson.key.decode_verify_key_base64( | ||||
|             "ed25519", "key1", "fP5l4JzpZPq/zdbBg5xx6lQGAAOM9/3w94cqiJ5jPrw" | ||||
|         ) | ||||
|         key2 = signedjson.key.decode_verify_key_base64( | ||||
|             "ed25519", "key2", "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw" | ||||
|         ) | ||||
|         yield self.store.store_server_verify_key( | ||||
|             "server1", "from_server", 0, key1 | ||||
|         ) | ||||
|         yield self.store.store_server_verify_key( | ||||
|             "server1", "from_server", 0, key2 | ||||
|         ) | ||||
| 
 | ||||
|         res = yield self.store.get_server_verify_keys( | ||||
|             "server1", ["ed25519:key1", "ed25519:key2", "ed25519:key3"]) | ||||
| 
 | ||||
|         self.assertEqual(len(res.keys()), 2) | ||||
|         self.assertEqual(res["ed25519:key1"].version, "key1") | ||||
|         self.assertEqual(res["ed25519:key2"].version, "key2") | ||||
							
								
								
									
										14
									
								
								tests/util/caches/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										14
									
								
								tests/util/caches/__init__.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,14 @@ | ||||
| # -*- coding: utf-8 -*- | ||||
| # Copyright 2017 Vector Creations Ltd | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
							
								
								
									
										177
									
								
								tests/util/caches/test_descriptors.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										177
									
								
								tests/util/caches/test_descriptors.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,177 @@ | ||||
| # -*- coding: utf-8 -*- | ||||
| # Copyright 2016 OpenMarket Ltd | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| import logging | ||||
| 
 | ||||
| import mock | ||||
| from synapse.api.errors import SynapseError | ||||
| from synapse.util import async | ||||
| from synapse.util import logcontext | ||||
| from twisted.internet import defer | ||||
| from synapse.util.caches import descriptors | ||||
| from tests import unittest | ||||
| 
 | ||||
| logger = logging.getLogger(__name__) | ||||
| 
 | ||||
| 
 | ||||
| class DescriptorTestCase(unittest.TestCase): | ||||
|     @defer.inlineCallbacks | ||||
|     def test_cache(self): | ||||
|         class Cls(object): | ||||
|             def __init__(self): | ||||
|                 self.mock = mock.Mock() | ||||
| 
 | ||||
|             @descriptors.cached() | ||||
|             def fn(self, arg1, arg2): | ||||
|                 return self.mock(arg1, arg2) | ||||
| 
 | ||||
|         obj = Cls() | ||||
| 
 | ||||
|         obj.mock.return_value = 'fish' | ||||
|         r = yield obj.fn(1, 2) | ||||
|         self.assertEqual(r, 'fish') | ||||
|         obj.mock.assert_called_once_with(1, 2) | ||||
|         obj.mock.reset_mock() | ||||
| 
 | ||||
|         # a call with different params should call the mock again | ||||
|         obj.mock.return_value = 'chips' | ||||
|         r = yield obj.fn(1, 3) | ||||
|         self.assertEqual(r, 'chips') | ||||
|         obj.mock.assert_called_once_with(1, 3) | ||||
|         obj.mock.reset_mock() | ||||
| 
 | ||||
|         # the two values should now be cached | ||||
|         r = yield obj.fn(1, 2) | ||||
|         self.assertEqual(r, 'fish') | ||||
|         r = yield obj.fn(1, 3) | ||||
|         self.assertEqual(r, 'chips') | ||||
|         obj.mock.assert_not_called() | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def test_cache_num_args(self): | ||||
|         """Only the first num_args arguments should matter to the cache""" | ||||
| 
 | ||||
|         class Cls(object): | ||||
|             def __init__(self): | ||||
|                 self.mock = mock.Mock() | ||||
| 
 | ||||
|             @descriptors.cached(num_args=1) | ||||
|             def fn(self, arg1, arg2): | ||||
|                 return self.mock(arg1, arg2) | ||||
| 
 | ||||
|         obj = Cls() | ||||
|         obj.mock.return_value = 'fish' | ||||
|         r = yield obj.fn(1, 2) | ||||
|         self.assertEqual(r, 'fish') | ||||
|         obj.mock.assert_called_once_with(1, 2) | ||||
|         obj.mock.reset_mock() | ||||
| 
 | ||||
|         # a call with different params should call the mock again | ||||
|         obj.mock.return_value = 'chips' | ||||
|         r = yield obj.fn(2, 3) | ||||
|         self.assertEqual(r, 'chips') | ||||
|         obj.mock.assert_called_once_with(2, 3) | ||||
|         obj.mock.reset_mock() | ||||
| 
 | ||||
|         # the two values should now be cached; we should be able to vary | ||||
|         # the second argument and still get the cached result. | ||||
|         r = yield obj.fn(1, 4) | ||||
|         self.assertEqual(r, 'fish') | ||||
|         r = yield obj.fn(2, 5) | ||||
|         self.assertEqual(r, 'chips') | ||||
|         obj.mock.assert_not_called() | ||||
| 
 | ||||
|     def test_cache_logcontexts(self): | ||||
|         """Check that logcontexts are set and restored correctly when | ||||
|         using the cache.""" | ||||
| 
 | ||||
|         complete_lookup = defer.Deferred() | ||||
| 
 | ||||
|         class Cls(object): | ||||
|             @descriptors.cached() | ||||
|             def fn(self, arg1): | ||||
|                 @defer.inlineCallbacks | ||||
|                 def inner_fn(): | ||||
|                     with logcontext.PreserveLoggingContext(): | ||||
|                         yield complete_lookup | ||||
|                     defer.returnValue(1) | ||||
| 
 | ||||
|                 return inner_fn() | ||||
| 
 | ||||
|         @defer.inlineCallbacks | ||||
|         def do_lookup(): | ||||
|             with logcontext.LoggingContext() as c1: | ||||
|                 c1.name = "c1" | ||||
|                 r = yield obj.fn(1) | ||||
|                 self.assertEqual(logcontext.LoggingContext.current_context(), | ||||
|                                  c1) | ||||
|             defer.returnValue(r) | ||||
| 
 | ||||
|         def check_result(r): | ||||
|             self.assertEqual(r, 1) | ||||
| 
 | ||||
|         obj = Cls() | ||||
| 
 | ||||
|         # set off a deferred which will do a cache lookup | ||||
|         d1 = do_lookup() | ||||
|         self.assertEqual(logcontext.LoggingContext.current_context(), | ||||
|                          logcontext.LoggingContext.sentinel) | ||||
|         d1.addCallback(check_result) | ||||
| 
 | ||||
|         # and another | ||||
|         d2 = do_lookup() | ||||
|         self.assertEqual(logcontext.LoggingContext.current_context(), | ||||
|                          logcontext.LoggingContext.sentinel) | ||||
|         d2.addCallback(check_result) | ||||
| 
 | ||||
|         # let the lookup complete | ||||
|         complete_lookup.callback(None) | ||||
| 
 | ||||
|         return defer.gatherResults([d1, d2]) | ||||
| 
 | ||||
|     def test_cache_logcontexts_with_exception(self): | ||||
|         """Check that the cache sets and restores logcontexts correctly when | ||||
|         the lookup function throws an exception""" | ||||
| 
 | ||||
|         class Cls(object): | ||||
|             @descriptors.cached() | ||||
|             def fn(self, arg1): | ||||
|                 @defer.inlineCallbacks | ||||
|                 def inner_fn(): | ||||
|                     yield async.run_on_reactor() | ||||
|                     raise SynapseError(400, "blah") | ||||
| 
 | ||||
|                 return inner_fn() | ||||
| 
 | ||||
|         @defer.inlineCallbacks | ||||
|         def do_lookup(): | ||||
|             with logcontext.LoggingContext() as c1: | ||||
|                 c1.name = "c1" | ||||
|                 try: | ||||
|                     yield obj.fn(1) | ||||
|                     self.fail("No exception thrown") | ||||
|                 except SynapseError: | ||||
|                     pass | ||||
| 
 | ||||
|                 self.assertEqual(logcontext.LoggingContext.current_context(), | ||||
|                                  c1) | ||||
| 
 | ||||
|         obj = Cls() | ||||
| 
 | ||||
|         # set off a deferred which will do a cache lookup | ||||
|         d1 = do_lookup() | ||||
|         self.assertEqual(logcontext.LoggingContext.current_context(), | ||||
|                          logcontext.LoggingContext.sentinel) | ||||
| 
 | ||||
|         return d1 | ||||
							
								
								
									
										33
									
								
								tests/util/test_clock.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										33
									
								
								tests/util/test_clock.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,33 @@ | ||||
| # -*- coding: utf-8 -*- | ||||
| # Copyright 2017 Vector Creations Ltd | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| from synapse import util | ||||
| from twisted.internet import defer | ||||
| from tests import unittest | ||||
| 
 | ||||
| 
 | ||||
| class ClockTestCase(unittest.TestCase): | ||||
|     @defer.inlineCallbacks | ||||
|     def test_time_bound_deferred(self): | ||||
|         # just a deferred which never resolves | ||||
|         slow_deferred = defer.Deferred() | ||||
| 
 | ||||
|         clock = util.Clock() | ||||
|         time_bound = clock.time_bound_deferred(slow_deferred, 0.001) | ||||
| 
 | ||||
|         try: | ||||
|             yield time_bound | ||||
|             self.fail("Expected timedout error, but got nothing") | ||||
|         except util.DeferredTimedOutError: | ||||
|             pass | ||||
| @ -1,8 +1,10 @@ | ||||
| import twisted.python.failure | ||||
| from twisted.internet import defer | ||||
| from twisted.internet import reactor | ||||
| from .. import unittest | ||||
| 
 | ||||
| from synapse.util.async import sleep | ||||
| from synapse.util import logcontext | ||||
| from synapse.util.logcontext import LoggingContext | ||||
| 
 | ||||
| 
 | ||||
| @ -33,3 +35,62 @@ class LoggingContextTestCase(unittest.TestCase): | ||||
|             context_one.test_key = "one" | ||||
|             yield sleep(0) | ||||
|             self._check_test_key("one") | ||||
| 
 | ||||
|     def _test_preserve_fn(self, function): | ||||
|         sentinel_context = LoggingContext.current_context() | ||||
| 
 | ||||
|         callback_completed = [False] | ||||
| 
 | ||||
|         @defer.inlineCallbacks | ||||
|         def cb(): | ||||
|             context_one.test_key = "one" | ||||
|             yield function() | ||||
|             self._check_test_key("one") | ||||
| 
 | ||||
|             callback_completed[0] = True | ||||
| 
 | ||||
|         with LoggingContext() as context_one: | ||||
|             context_one.test_key = "one" | ||||
| 
 | ||||
|             # fire off function, but don't wait on it. | ||||
|             logcontext.preserve_fn(cb)() | ||||
| 
 | ||||
|             self._check_test_key("one") | ||||
| 
 | ||||
|         # now wait for the function under test to have run, and check that | ||||
|         # the logcontext is left in a sane state. | ||||
|         d2 = defer.Deferred() | ||||
| 
 | ||||
|         def check_logcontext(): | ||||
|             if not callback_completed[0]: | ||||
|                 reactor.callLater(0.01, check_logcontext) | ||||
|                 return | ||||
| 
 | ||||
|             # make sure that the context was reset before it got thrown back | ||||
|             # into the reactor | ||||
|             try: | ||||
|                 self.assertIs(LoggingContext.current_context(), | ||||
|                               sentinel_context) | ||||
|                 d2.callback(None) | ||||
|             except BaseException: | ||||
|                 d2.errback(twisted.python.failure.Failure()) | ||||
| 
 | ||||
|         reactor.callLater(0.01, check_logcontext) | ||||
| 
 | ||||
|         # test is done once d2 finishes | ||||
|         return d2 | ||||
| 
 | ||||
|     def test_preserve_fn_with_blocking_fn(self): | ||||
|         @defer.inlineCallbacks | ||||
|         def blocking_function(): | ||||
|             yield sleep(0) | ||||
| 
 | ||||
|         return self._test_preserve_fn(blocking_function) | ||||
| 
 | ||||
|     def test_preserve_fn_with_non_blocking_fn(self): | ||||
|         @defer.inlineCallbacks | ||||
|         def nonblocking_function(): | ||||
|             with logcontext.PreserveLoggingContext(): | ||||
|                 yield defer.succeed(None) | ||||
| 
 | ||||
|         return self._test_preserve_fn(nonblocking_function) | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user