diff --git a/bin/cros_au_test_harness.py b/bin/cros_au_test_harness.py index 2449587df8..0e484947ab 100755 --- a/bin/cros_au_test_harness.py +++ b/bin/cros_au_test_harness.py @@ -8,6 +8,8 @@ import optparse import os import re import sys +import thread +import time import unittest import urllib @@ -19,6 +21,8 @@ from cros_build_lib import RunCommand from cros_build_lib import RunCommandCaptureOutput from cros_build_lib import Warning +import cros_test_proxy + # VM Constants. _FULL_VDISK_SIZE = 6072 _FULL_STATEFULFS_SIZE = 3074 @@ -90,20 +94,21 @@ class AUTest(object): else: self._UpdateImageReportError(image) - def _UpdateImageReportError(self, image_path, stateful_change='old'): + def _UpdateImageReportError(self, image_path, stateful_change='old', + proxy_port=None): """Calls UpdateImage and reports any error to the console. Still throws the exception. """ try: - self.UpdateImage(image_path, stateful_change) + self.UpdateImage(image_path, stateful_change, proxy_port) except UpdateException as err: # If the update fails, print it out Warning(err.stdout) raise def _AttemptUpdateWithPayloadExpectedFailure(self, payload, expected_msg): - # This update is expected to fail... + """Attempt a payload update, expect it to fail with expected log""" try: self.UpdateUsingPayload(payload) except UpdateException as err: @@ -115,11 +120,32 @@ class AUTest(object): Warning(err.stdout) self.fail('We managed to update when failure was expected') + def _AttemptUpdateWithFilter(self, filter): + """Update through a proxy, with a specified filter, and expect success.""" + + self.PrepareBase(target_image_path) + + # The devserver runs at port 8080 by default. We assume that here, and + # start our proxy at 8081. We then tell our update tools to have the + # client connect to 8081 instead of 8080. + proxy_port = 8081 + proxy = cros_test_proxy.CrosTestProxy(port_in=proxy_port, + address_out='127.0.0.1', + port_out=8080, + filter=filter) + proxy.serve_forever_in_thread() + + # This update is expected to fail... + try: + self._UpdateImageReportError(target_image_path, proxy_port=proxy_port) + finally: + proxy.shutdown() + def PrepareBase(self, image_path): """Prepares target with base_image_path.""" pass - def UpdateImage(self, image_path, stateful_change='old'): + def UpdateImage(self, image_path, stateful_change='old', proxy_port=None): """Updates target with the image given by the image_path. Args: @@ -129,15 +155,22 @@ class AUTest(object): 'old': Don't modify stateful partition. Just update normally. 'clean': Uses clobber-state to wipe the stateful partition with the exception of code needed for ssh. + proxy_port: Port to have the client connect to. For use with + CrosTestProxy. """ pass - def UpdateUsingPayload(self, update_path, stateful_change='old'): + def UpdateUsingPayload(self, + update_path, + stateful_change='old', + proxy_port=None): """Updates target with the pre-generated update stored in update_path Args: update_path: Path to the image to update with. This directory should - contain both update.gz, and stateful.image.gz + contain both update.gz, and stateful.image.gz + proxy_port: Port to have the client connect to. For use with + CrosTestProxy. """ pass @@ -187,7 +220,7 @@ class AUTest(object): """ # Just make sure some tests pass on original image. Some old images # don't pass many tests. - self.PrepareBase(image_path=base_image_path) + self.PrepareBase(base_image_path) # TODO(sosa): move to 100% once we start testing using the autotest paired # with the dev channel. percent_passed = self.VerifyImage(10) @@ -210,7 +243,7 @@ class AUTest(object): """ # Just make sure some tests pass on original image. Some old images # don't pass many tests. - self.PrepareBase(image_path=base_image_path) + self.PrepareBase(base_image_path) # TODO(sosa): move to 100% once we start testing using the autotest paired # with the dev channel. percent_passed = self.VerifyImage(10) @@ -228,7 +261,7 @@ class AUTest(object): def testPartialUpdate(self): """Tests what happens if we attempt to update with a truncated payload.""" # Preload with the version we are trying to test. - self.PrepareBase(image_path=target_image_path) + self.PrepareBase(target_image_path) # Image can be updated at: # ~chrome-eng/chromeos/localmirror/autest-images @@ -245,7 +278,7 @@ class AUTest(object): def testCorruptedUpdate(self): """Tests what happens if we attempt to update with a corrupted payload.""" # Preload with the version we are trying to test. - self.PrepareBase(image_path=target_image_path) + self.PrepareBase(target_image_path) # Image can be updated at: # ~chrome-eng/chromeos/localmirror/autest-images @@ -260,6 +293,71 @@ class AUTest(object): expected_msg = 'zlib inflate() error:-3' self._AttemptUpdateWithPayloadExpectedFailure(payload, expected_msg) + def testInterruptedUpdate(self): + """Tests what happens if we interrupt payload delivery 3 times.""" + + class InterruptionFilter(cros_test_proxy.Filter): + """This filter causes the proxy to interrupt the download 3 times + + It does this by closing the first three connections to transfer + 2M total in the outbound connection after they transfer the + 2M. + """ + def __init__(self): + """Defines variable shared across all connections""" + self.close_count = 0 + + def setup(self): + """Called once at the start of each connection.""" + self.data_size = 0 + + def OutBound(self, data): + """Called once per packet for outgoing data. + + The first three connections transferring more than 2M + outbound will be closed. + """ + if self.close_count < 3: + if self.data_size > (2 * 1024 * 1024): + self.close_count += 1 + return None + + self.data_size += len(data) + return data + + self._AttemptUpdateWithFilter(InterruptionFilter()) + + def testDelayedUpdate(self): + """Tests what happens if some data is delayed during update delivery""" + + class DelayedFilter(cros_test_proxy.Filter): + """Causes intermittent delays in data transmission. + + It does this by inserting 3 20 second delays when transmitting + data after 2M has been sent. + """ + def setup(self): + """Called once at the start of each connection.""" + self.data_size = 0 + self.delay_count = 0 + + def OutBound(self, data): + """Called once per packet for outgoing data. + + The first three packets after we reach 2M transferred + are delayed by 20 seconds. + """ + if self.delay_count < 3: + if self.data_size > (2 * 1024 * 1024): + self.delay_count += 1 + time.sleep(20) + + self.data_size += len(data) + return data + + + self._AttemptUpdateWithFilter(DelayedFilter()) + class RealAUTest(unittest.TestCase, AUTest): """Test harness for updating real images.""" @@ -270,7 +368,7 @@ class RealAUTest(unittest.TestCase, AUTest): """Auto-update to base image to prepare for test.""" self._UpdateImageReportError(image_path) - def UpdateImage(self, image_path, stateful_change='old'): + def UpdateImage(self, image_path, stateful_change='old', proxy_port=None): """Updates a remote image using image_to_live.sh.""" stateful_change_flag = self.GetStatefulChangeFlag(stateful_change) cmd = ['%s/image_to_live.sh' % self.crosutils, @@ -281,6 +379,9 @@ class RealAUTest(unittest.TestCase, AUTest): '--src_image=%s' % self.source_image ] + if proxy_port: + cmd.append('--proxy_port=%s' % proxy_port) + if self.verbose: try: RunCommand(cmd) @@ -291,7 +392,10 @@ class RealAUTest(unittest.TestCase, AUTest): if code != 0: raise UpdateException(code, stdout) - def UpdateUsingPayload(self, update_path, stateful_change='old'): + def UpdateUsingPayload(self, + update_path, + stateful_change='old', + proxy_port=None): """Updates a remote image using image_to_live.sh.""" stateful_change_flag = self.GetStatefulChangeFlag(stateful_change) cmd = ['%s/image_to_live.sh' % self.crosutils, @@ -301,6 +405,9 @@ class RealAUTest(unittest.TestCase, AUTest): '--verify', ] + if proxy_port: + cmd.append('--proxy_port=%s' % proxy_port) + if self.verbose: try: RunCommand(cmd) @@ -366,7 +473,7 @@ class VirtualAUTest(unittest.TestCase, AUTest): self.assertTrue(os.path.exists(self.vm_image_path)) - def UpdateImage(self, image_path, stateful_change='old'): + def UpdateImage(self, image_path, stateful_change='old', proxy_port=None): """Updates VM image with image_path.""" stateful_change_flag = self.GetStatefulChangeFlag(stateful_change) if self.source_image == base_image_path: @@ -382,6 +489,10 @@ class VirtualAUTest(unittest.TestCase, AUTest): stateful_change_flag, '--src_image=%s' % self.source_image, ] + + if proxy_port: + cmd.append('--proxy_port=%s' % proxy_port) + if self.verbose: try: RunCommand(cmd) @@ -392,7 +503,10 @@ class VirtualAUTest(unittest.TestCase, AUTest): if code != 0: raise UpdateException(code, stdout) - def UpdateUsingPayload(self, update_path, stateful_change='old'): + def UpdateUsingPayload(self, + update_path, + stateful_change='old', + proxy_port=None): """Updates a remote image using image_to_live.sh.""" stateful_change_flag = self.GetStatefulChangeFlag(stateful_change) if self.source_image == base_image_path: @@ -409,6 +523,9 @@ class VirtualAUTest(unittest.TestCase, AUTest): '--src_image=%s' % self.source_image, ] + if proxy_port: + cmd.append('--proxy_port=%s' % proxy_port) + if self.verbose: try: RunCommand(cmd) diff --git a/image_to_live.sh b/image_to_live.sh index 3988470254..545e428b92 100755 --- a/image_to_live.sh +++ b/image_to_live.sh @@ -38,6 +38,8 @@ DEFINE_string image "" \ "Update with this image path that is in this source checkout." i DEFINE_string payload "" \ "Update with this update payload, ignoring specified images." +DEFINE_string proxy_port "" \ + "Have the client request from this proxy instead of devserver." DEFINE_string src_image "" \ "Create a delta update by passing in the image on the remote machine." DEFINE_boolean update_stateful ${FLAGS_TRUE} \ @@ -139,6 +141,11 @@ function start_dev_server { --payload $(reinterpret_path_for_chroot ${FLAGS_payload})" fi + if [ -n "${FLAGS_proxy_port}" ]; then + devserver_flags="${devserver_flags} \ + --proxy_port ${FLAGS_proxy_port}" + fi + [ ${FLAGS_for_vm} -eq ${FLAGS_TRUE} ] && \ devserver_flags="${devserver_flags} --for_vm" @@ -209,9 +216,15 @@ function get_update_args { function get_devserver_url { local devserver_url="" + local port=${FLAGS_devserver_port} + + if [[ -n ${FLAGS_proxy_port} ]]; then + port=${FLAGS_proxy_port} + fi + if [ ${FLAGS_ignore_hostname} -eq ${FLAGS_TRUE} ]; then if [ -z ${FLAGS_update_url} ]; then - devserver_url="http://$(get_hostname):${FLAGS_devserver_port}/update" + devserver_url="http://$(get_hostname):${port}/update" else devserver_url="${FLAGS_update_url}" fi diff --git a/lib/cros_test_proxy.py b/lib/cros_test_proxy.py new file mode 100755 index 0000000000..21709829c3 --- /dev/null +++ b/lib/cros_test_proxy.py @@ -0,0 +1,113 @@ +# Copyright (c) 2010 The Chromium OS Authors. All rights reserved. +# Use of this source code is governed by a BSD-style license that can be +# found in the LICENSE file. + +import select +import socket +import SocketServer +import threading + +class Filter(object): + """Base class for data filters. + + Pass subclass of this to CrosTestProxy which will perform whatever + connection manipulation you prefer. + """ + + def setup(self): + """This setup method is called once per connection.""" + pass + + def InBound(self, data): + """This method is called once per packet of incoming data. + + The value returned is what is sent through the proxy. If + None is returned, the connection will be closed. + """ + return data + + def OutBound(self, data): + """This method is called once per packet of outgoing data. + + The value returned is what is sent through the proxy. If + None is returned, the connection will be closed. + """ + return data + + +class CrosTestProxy(SocketServer.ThreadingMixIn, SocketServer.TCPServer): + """A transparent proxy for simulating network errors""" + + class _Handler(SocketServer.BaseRequestHandler): + """Proxy connection handler that passes data though a filter""" + + def setup(self): + """Setup is called once for each connection proxied.""" + self.server.filter.setup() + + def handle(self): + """Handles each incoming connection. + + Opens a new connection to the port we are proxing to, then + passes each packet along in both directions after passing + them through the filter object passed in. + """ + # Open outgoing socket + s_in = self.request + s_out = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s_out.connect((self.server.address_out, self.server.port_out)) + + while True: + rlist, wlist, xlist = select.select([s_in, s_out], [], []) + + if s_in in rlist: + data = s_in.recv(1024) + data = self.server.filter.InBound(data) + if not data: break + try: + # If there is any error sending data, close both connections. + s_out.sendall(data) + except socket.error: + break + + if s_out in rlist: + data = s_out.recv(1024) + data = self.server.filter.OutBound(data) + if not data: break + try: + # If there is any error sending data, close both connections. + s_in.sendall(data) + except socket.error: + break + + s_in.close() + s_out.close() + + def __init__(self, + filter, + port_in=8081, + address_out='127.0.0.1', port_out=8080): + """Configures the proxy object. + + Args: + filter: An instance of a subclass of Filter. + port_in: Port on which to listen for incoming connections. + address_out: Address to which outgoing connections will go. + address_port: Port to which outgoing connections will go. + """ + self.port_in = port_in + self.address_out = address_out + self.port_out = port_out + self.filter = filter + + SocketServer.TCPServer.__init__(self, + ('', port_in), + self._Handler) + + def serve_forever_in_thread(self): + """Helper method to start the server in a new background thread.""" + server_thread = threading.Thread(target=self.serve_forever) + server_thread.setDaemon(True) + server_thread.start() + + return server_thread