| # Copyright 2015 The Chromium Authors. All rights reserved. | 
 | # Use of this source code is governed by a BSD-style license that can be | 
 | # found in the LICENSE file. | 
 |  | 
 | """Defines the task controller library.""" | 
 |  | 
 | import argparse | 
 | import datetime | 
 | import logging | 
 | import os | 
 | import socket | 
 | import subprocess | 
 | import sys | 
 | import threading | 
 |  | 
 | from legion.lib import common_lib | 
 | from legion.lib import process | 
 | from legion.lib.rpc import rpc_server | 
 | from legion.lib.rpc import jsonrpclib | 
 |  | 
 | ISOLATE_PY = os.path.join(common_lib.SWARMING_DIR, 'isolate.py') | 
 | SWARMING_PY = os.path.join(common_lib.SWARMING_DIR, 'swarming.py') | 
 |  | 
 |  | 
 | class Error(Exception): | 
 |   pass | 
 |  | 
 |  | 
 | class ConnectionTimeoutError(Error): | 
 |   pass | 
 |  | 
 |  | 
 | class TaskController(object): | 
 |   """Provisions, configures, and controls a task machine. | 
 |  | 
 |   This class is an abstraction of a physical task machine. It provides an | 
 |   end to end API for controlling a task machine. Operations on the task machine | 
 |   are performed using the instance's "rpc" property. A simple end to end | 
 |   scenario is as follows: | 
 |  | 
 |   task = TaskController(...) | 
 |   task.Create() | 
 |   task.WaitForConnection() | 
 |   proc = task.rpc.subprocess.Popen(['ls']) | 
 |   print task.rpc.subprocess.GetStdout(proc) | 
 |   task.Release() | 
 |   """ | 
 |  | 
 |   _task_count = 0 | 
 |   _tasks = [] | 
 |  | 
 |   def __init__(self, isolated_hash, dimensions, reg_server_port, priority=100, | 
 |                idle_timeout_secs=common_lib.DEFAULT_TIMEOUT_SECS, | 
 |                connection_timeout_secs=common_lib.DEFAULT_TIMEOUT_SECS, | 
 |                verbosity='ERROR', name=None, run_id=None): | 
 |     assert isinstance(dimensions, dict) | 
 |     type(self)._tasks.append(self) | 
 |     type(self)._task_count += 1 | 
 |     self.verbosity = verbosity | 
 |     self._name = name or 'Task%d' % type(self)._task_count | 
 |     self._priority = priority | 
 |     self._isolated_hash = isolated_hash | 
 |     self._idle_timeout_secs = idle_timeout_secs | 
 |     self._dimensions = dimensions | 
 |     self._connect_event = threading.Event() | 
 |     self._connected = False | 
 |     self._ip_address = None | 
 |     self._reg_server_port = reg_server_port | 
 |     self._otp = self._CreateOTP() | 
 |     self._rpc = None | 
 |     self._output_dir = None | 
 |     self._platform = None | 
 |     self._executable = None | 
 |     self._task_rpc_port = None | 
 |  | 
 |     run_id = run_id or datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S') | 
 |     self._task_name = '%s/%s/%s' % ( | 
 |         os.path.splitext(sys.argv[0])[0], self._name, run_id) | 
 |  | 
 |     parser = argparse.ArgumentParser() | 
 |     parser.add_argument('--isolate-server') | 
 |     parser.add_argument('--swarming-server') | 
 |     parser.add_argument('--task-connection-timeout-secs', | 
 |                         default=common_lib.DEFAULT_TIMEOUT_SECS) | 
 |     args, _ = parser.parse_known_args() | 
 |  | 
 |     self._isolate_server = args.isolate_server | 
 |     self._swarming_server = args.swarming_server | 
 |     self._connection_timeout_secs = (connection_timeout_secs or | 
 |                                     args.task_connection_timeout_secs) | 
 |  | 
 |     # Register for the shutdown event | 
 |     common_lib.OnShutdown += self.Release | 
 |  | 
 |   @property | 
 |   def name(self): | 
 |     return self._name | 
 |  | 
 |   @property | 
 |   def otp(self): | 
 |     return self._otp | 
 |  | 
 |   @property | 
 |   def connected(self): | 
 |     return self._connected | 
 |  | 
 |   @property | 
 |   def connect_event(self): | 
 |     return self._connect_event | 
 |  | 
 |   @property | 
 |   def rpc(self): | 
 |     return self._rpc | 
 |  | 
 |   @property | 
 |   def verbosity(self): | 
 |     return self._verbosity | 
 |  | 
 |   @verbosity.setter | 
 |   def verbosity(self, level): | 
 |     """Sets the verbosity level as a string. | 
 |  | 
 |     Either a string ('INFO', 'DEBUG', etc) or a logging level (logging.INFO, | 
 |     logging.DEBUG, etc) is allowed. | 
 |     """ | 
 |     assert isinstance(level, (str, int)) | 
 |     if isinstance(level, int): | 
 |       level = logging.getLevelName(level) | 
 |     self._verbosity = level  #pylint: disable=attribute-defined-outside-init | 
 |  | 
 |   @property | 
 |   def output_dir(self): | 
 |     if not self._output_dir: | 
 |       self._output_dir = self.rpc.GetOutputDir() | 
 |     return self._output_dir | 
 |  | 
 |   @property | 
 |   def platform(self): | 
 |     if not self._platform: | 
 |       self._platform = self._rpc.GetPlatform() | 
 |     return self._platform | 
 |  | 
 |   @property | 
 |   def ip_address(self): | 
 |     if not self._ip_address: | 
 |       self._ip_address = self.rpc.GetIpAddress() | 
 |     return self._ip_address | 
 |  | 
 |   @property | 
 |   def executable(self): | 
 |     if not self._executable: | 
 |       self._executable = self.rpc.GetExecutable() | 
 |     return self._executable | 
 |  | 
 |   @classmethod | 
 |   def ReleaseAllTasks(cls): | 
 |     for task in cls._tasks: | 
 |       task.Release() | 
 |  | 
 |   def Process(self, cmd, *args, **kwargs): | 
 |     return process.ControllerProcessWrapper(self.rpc, cmd, *args, **kwargs) | 
 |  | 
 |   def _CreateOTP(self): | 
 |     """Creates the OTP.""" | 
 |     controller_name = socket.gethostname() | 
 |     test_name = os.path.basename(sys.argv[0]) | 
 |     creation_time = datetime.datetime.utcnow() | 
 |     otp = 'task:%s controller:%s port: %d test:%s creation:%s' % ( | 
 |         self._name, controller_name, self._reg_server_port, test_name, | 
 |         creation_time) | 
 |     return otp | 
 |  | 
 |   def Create(self): | 
 |     """Creates the task machine.""" | 
 |     logging.info('Creating %s', self.name) | 
 |     self._connect_event.clear() | 
 |     self._ExecuteSwarming() | 
 |  | 
 |   def WaitForConnection(self): | 
 |     """Waits for the task machine to connect. | 
 |  | 
 |     Raises: | 
 |       ConnectionTimeoutError if the task doesn't connect in time. | 
 |     """ | 
 |     logging.info('Waiting for %s to connect with a timeout of %d seconds', | 
 |                  self._name, self._connection_timeout_secs) | 
 |     self._connect_event.wait(self._connection_timeout_secs) | 
 |     if not self._connect_event.is_set(): | 
 |       raise ConnectionTimeoutError('%s failed to connect' % self.name) | 
 |  | 
 |   def Release(self): | 
 |     """Quits the task's RPC server so it can release the machine.""" | 
 |     if self._rpc is not None and self._connected: | 
 |       logging.info('Copying output-dir files to controller') | 
 |       self.RetrieveOutputFiles() | 
 |       logging.info('Releasing %s', self._name) | 
 |       try: | 
 |         self._rpc.Quit() | 
 |       except (socket.error, jsonrpclib.Fault): | 
 |         logging.error('Unable to connect to %s to call Quit', self.name) | 
 |       self._rpc = None | 
 |       self._connected = False | 
 |  | 
 |   def _ExecuteSwarming(self): | 
 |     """Executes swarming.py.""" | 
 |     cmd = [ | 
 |         'python', | 
 |         SWARMING_PY, | 
 |         'trigger', | 
 |         self._isolated_hash, | 
 |         '--priority', str(self._priority), | 
 |         '--task-name', self._task_name, | 
 |         ] | 
 |  | 
 |     if self._isolate_server: | 
 |       cmd.extend(['--isolate-server', self._isolate_server]) | 
 |     if self._swarming_server: | 
 |       cmd.extend(['--swarming', self._swarming_server]) | 
 |     for key, value in self._dimensions.iteritems(): | 
 |       cmd.extend(['--dimension', key, value]) | 
 |  | 
 |     cmd.extend([ | 
 |         '--', | 
 |         '--controller', common_lib.MY_IP, | 
 |         '--controller-port', str(self._reg_server_port), | 
 |         '--otp', self._otp, | 
 |         '--verbosity', self._verbosity, | 
 |         '--idle-timeout', str(self._idle_timeout_secs), | 
 |         '--output-dir', '${ISOLATED_OUTDIR}' | 
 |         ]) | 
 |  | 
 |     self._ExecuteProcess(cmd) | 
 |  | 
 |   def _ExecuteProcess(self, cmd): | 
 |     """Executes a process, waits for it to complete, and checks for success.""" | 
 |     logging.debug('Running %s', ' '.join(cmd)) | 
 |     p = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) | 
 |     _, stderr = p.communicate() | 
 |     if p.returncode != 0: | 
 |       raise Error(stderr) | 
 |  | 
 |   def OnConnect(self, ip_address, rpc_port): | 
 |     """Receives task ip address and port on connection.""" | 
 |     self._ip_address = ip_address | 
 |     self._task_rpc_port = rpc_port | 
 |     self._connected = True | 
 |     self._rpc = rpc_server.RpcServer.Connect(self._ip_address, | 
 |                                              self._task_rpc_port) | 
 |     logging.info('%s connected from %s:%s', self._name, ip_address, | 
 |                  self._task_rpc_port) | 
 |     self._connect_event.set() | 
 |  | 
 |   def RetrieveOutputFiles(self): | 
 |     """Retrieves all files in the output-dir.""" | 
 |     files = self.rpc.ListDir(self.output_dir) | 
 |     for fname in files: | 
 |       remote_path = self.rpc.PathJoin(self.output_dir, fname) | 
 |       local_name = os.path.join(common_lib.GetOutputDir(), | 
 |                                 '%s.%s' % (self.name, fname)) | 
 |       contents = self.rpc.ReadFile(remote_path) | 
 |       with open(local_name, 'wb+') as fh: | 
 |         fh.write(contents) |