# Source code for ariane.app.tas.server._server

from ....lib.communication import Server

from abc import abstractmethod
from datetime import datetime
from numpy.random import RandomState
from threading import Thread


# Port used by `TASServer` when no explicit `port` argument is given.
PORT_DEFAULT: int = 11657


class TASServer(Server):
    """Base class for servers in TAS application.

    It implements processing actions with corresponding data and accordingly
    calls template methods of derived classes.
    """

    def __init__(self, *, version, method, config, random_seed=None, port=PORT_DEFAULT):
        super().__init__('ARIANE', port, config=config)
        self.version = version
        self.method = method
        # random state for reproducibility
        self.random_seed = random_seed
        self.random_state = RandomState(random_seed)
        # True while a worker thread (method initialization or result
        # processing) is running; every action is then answered with `busy`
        self.busy = False
        # session timestamps, set by the 'reset' and 'stop' actions; initialized
        # here so that a 'stop' without a preceding 'reset' does not fail
        self.start_time = None
        self.end_time = None

    def process(self, action: str, data: dict):
        """Process `action` with corresponding `data` and accordingly call
        template methods of derived classes.

        Parameters
        ----------
        action : str
            One of 'ping', 'reset', 'result', 'next_loc', 'heuris_experi_param',
            'problem_locs', 'state_internal' or 'stop'.
        data : dict
            Action-specific payload.

        Returns
        -------
        ret_data : dict
            Response data. It contains ``'busy': True`` if the server is busy,
            or if 'next_loc' just started the method initialization in a
            worker thread.

        Raises
        ------
        RuntimeError
            If 'next_loc' is requested while the method is neither initialized
            nor initializable, or 'state_internal' is requested before the
            method is initialized.
        ValueError
            If `action` is unknown.
        """
        ret_data = {}
        if self.busy:
            # a worker thread is still running: reject every action for now
            ret_data['busy'] = True
            return ret_data

        if action == 'ping':
            ret_data['method'] = self.method
            ret_data['version'] = self.version
        elif action == 'reset':
            # reset server instance for fresh use
            self.start_time = datetime.now()
            self.end_time = None
            self.logger.info(
                f"Start time: {self.start_time.isoformat(' ', timespec='seconds')}")
            mode = data['mode']
            axes = data['axes']
            offset = data['offset']
            limits = data['limits']
            # background level and intensity threshold are optional (`None`)
            level_backgr = data.get('level_backgr')
            thresh_intens = data.get('thresh_intens')
            # maximum travel cost for normalizing travel costs; falls back to 1
            # if the key is missing or explicitly `None`
            travel_cost_max = data.get('travel_cost_max')
            if travel_cost_max is None:
                travel_cost_max = 1
            scenario_name = data.get('scenario_name')
            assert len(axes) == len(limits)
            assert level_backgr is None or level_backgr >= 0
            assert thresh_intens is None or thresh_intens >= 0
            assert travel_cost_max > 0
            self.reset(mode, axes, offset, limits, level_backgr, thresh_intens,
                       travel_cost_max, scenario_name)
        elif action == 'result':
            # process new results, if necessary in a "busy thread"
            locs = data['locs']
            counts = data['counts']
            matrices_ellipses = data['matrices_ellipses']
            travel_time = data['travel_time']
            counting_time = data['counting_time']
            travel_cost_grid = data.get('travel_cost_grid')
            travel_cost_values = data.get('travel_cost_values')
            assert len(locs) == len(counts) == len(matrices_ellipses)
            args = (locs, counts, matrices_ellipses, travel_time, counting_time,
                    travel_cost_grid, travel_cost_values)
            if not self.method_initialized or self.param_optim_stopped:
                # before initialization, or once hyperparameter optimization
                # has stopped, results are processed synchronously
                self.result(*args)
                self.save_state(id='1_interm')
            else:
                self._start_busy_thread(ResultThread(self, *args))
        elif action == 'next_loc':
            # compute next location based on current results or initialize
            # method in a "busy thread" if not done yet
            if self.method_initialized:
                ret_data['loc'], ret_data['stop'] = self.next_loc()
            elif self.method_initializable:
                self._start_busy_thread(InitThread(self))
                ret_data['busy'] = True
            else:
                raise RuntimeError('method is not initializable')
        elif action == 'heuris_experi_param':
            # heuristically compute experiment parameters based on current results
            ret_data['level_backgr'], ret_data['thresh_intens'] = self.heuris_experi_param()
        elif action == 'problem_locs':
            # process problematic locations
            self.problem_locs(data['locs'], data['matrices_ellipses'])
        elif action == 'state_internal':
            if not self.method_initialized:
                raise RuntimeError("server is not initialized")
            # provide internal state of server
            ret_data.update(self.state_internal(data))
        elif action == 'stop':
            # finalize: log final information, save final state, stop server
            self.end_time = datetime.now()
            self.logger.info(f"End time: {self.end_time.isoformat(' ', timespec='seconds')}")
            # the elapsed time is only known if a 'reset' has set `start_time`
            if self.start_time is not None:
                hours = (self.end_time - self.start_time).total_seconds() / 3600.
                self.logger.info(f"Total hours exceeded: {hours}")
            self.save_state(id='2_final')
            self.stop()
        else:
            raise ValueError('unknown action')
        return ret_data

    def _start_busy_thread(self, thread):
        """Flag the server as busy, then start `thread`.

        The flag is raised here, in the calling thread, so that it is already
        set when `process` returns. Otherwise a request arriving before the new
        thread is scheduled could pass the busy check. If starting fails, the
        flag is cleared again because no thread exists to clear it.
        """
        self.busy = True
        try:
            thread.start()
        except BaseException:
            self.busy = False
            raise

    @property
    @abstractmethod
    def method_initializable(self):
        """Indicate if the method can be initialized.

        Returns
        -------
        method_initializable : bool
        """
        raise NotImplementedError

    @abstractmethod
    def _init_method(self):
        """Initialize the method; run in a worker thread started by 'next_loc'."""
        raise NotImplementedError

    @property
    @abstractmethod
    def method_initialized(self):
        """Indicate if the method is initialized.

        Returns
        -------
        method_initialized : bool
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def param_optim_stopped(self):
        """Indicate if optimization of kernel hyperparameters is stopped.

        Returns
        -------
        param_optim_stopped : bool
        """
        raise NotImplementedError

    @abstractmethod
    def reset(self, mode, axes, offset, limits, level_backgr=None, thresh_intens=None,
              travel_cost_max=None, scenario_name=None):
        """Reset the server instance for fresh use (action 'reset')."""
        raise NotImplementedError

    @abstractmethod
    def result(self, locs, counts, matrices_ellipses, travel_time, counting_time,
               travel_cost_grid=None, travel_cost_values=None):
        """Process new results (action 'result'); may run in a worker thread."""
        raise NotImplementedError

    @abstractmethod
    def next_loc(self):
        """Return ``(loc, stop)``, sent back as the 'loc' and 'stop' response entries."""
        raise NotImplementedError

    @abstractmethod
    def heuris_experi_param(self):
        """Return ``(level_backgr, thresh_intens)`` computed heuristically from
        the current results."""
        raise NotImplementedError

    @abstractmethod
    def problem_locs(self, locs, matrices_ellipses):
        """Process problematic locations (action 'problem_locs')."""
        raise NotImplementedError

    @abstractmethod
    def state_internal(self, data):
        """Return a mapping that is merged into the 'state_internal' response.

        Only called once the method is initialized.
        """
        raise NotImplementedError

    @abstractmethod
    def stop(self):
        """Stop the server; called at the end of the 'stop' action."""
        raise NotImplementedError

    @abstractmethod
    def save_state(self, id=None):
        """Save the current state of the server to a file.

        If `id` is not None, it is included in the filename.
        """
        # NOTE: parameter name `id` shadows the builtin but is kept for callers
        raise NotImplementedError
class BusyThread(Thread):
    """Base class for threads setting a server instance to `busy`.

    The server's `busy` flag is raised when the thread is started. It is
    lowered again once `_do_run` returns or raises, so a failing worker never
    leaves the server permanently busy.
    """

    def __init__(self, server: 'TASServer'):
        super().__init__()
        self.server = server

    def start(self):
        """Flag the server as busy, then start the thread.

        Raising the flag in the calling thread, rather than only in `run`,
        guarantees that it is already set when `start` returns. A request
        arriving before the new thread is scheduled is therefore still
        rejected as busy.
        """
        self.server.busy = True
        try:
            super().start()
        except BaseException:
            # the thread never ran, so nothing else would clear the flag
            self.server.busy = False
            raise

    def run(self):
        # also set here so that calling `run()` directly behaves the same
        self.server.busy = True
        try:
            self._do_run()
        finally:
            # always release the server, even if the work raised
            self.server.busy = False

    @abstractmethod
    def _do_run(self):
        """Do the actual work of the thread; implemented by subclasses."""
        raise NotImplementedError


class InitThread(BusyThread):
    """Thread class for initialization of method."""

    def _do_run(self):
        self.server._init_method()
        self.server.save_state(id='0_init')


class ResultThread(BusyThread):
    """Thread class for processing new results."""

    def __init__(self, server: 'TASServer', *args):
        super().__init__(server=server)
        # positional arguments forwarded unchanged to `server.result`
        self.args = args

    def _do_run(self):
        self.server.result(*self.args)
        self.server.save_state(id='1_interm')