from ....lib.communication import Server
from abc import abstractmethod
from datetime import datetime
from numpy.random import RandomState
from threading import Thread
# Port handed to the communication ``Server`` when no ``port`` is given.
PORT_DEFAULT = 11_657
class TASServer(Server):
    """Base class for servers in TAS application.

    It processes actions with their corresponding data and accordingly
    calls the template methods that derived classes implement.
    """

    def __init__(self, *, version, method, config, random_seed=None, port=PORT_DEFAULT):
        super().__init__('ARIANE', port, config=config)
        self.version = version
        self.method = method
        # random state for reproducibility
        self.random_seed = random_seed
        self.random_state = RandomState(random_seed)
        # True while a background thread processes a request; any request
        # arriving meanwhile is answered with ``{'busy': True}``
        self.busy = False
        # session timestamps, set by the 'reset' and 'stop' actions; defined
        # here so that 'stop' also works without a preceding 'reset'
        self.start_time = None
        self.end_time = None

    def process(self, action: str, data: dict):
        """Process `action` with corresponding `data` and accordingly
        call template methods of derived classes.

        Parameters
        ----------
        action : str
            One of ``'ping'``, ``'reset'``, ``'result'``, ``'next_loc'``,
            ``'heuris_experi_param'``, ``'problem_locs'``,
            ``'state_internal'`` or ``'stop'``.
        data : dict
            Action-specific payload.

        Returns
        -------
        ret_data : dict
            Response data. ``{'busy': True}`` if the server is busy.

        Raises
        ------
        KeyError
            If a required key is missing from `data`.
        AssertionError
            If the 'reset' or 'result' payload is inconsistent.
        RuntimeError
            If the method must be initialized but is not initializable, or
            the internal state is requested before initialization.
        ValueError
            If `action` is unknown.
        """
        ret_data = {}
        if self.busy:
            # a background thread is still working; reject this request
            ret_data['busy'] = True
            return ret_data

        if action == 'ping':
            ret_data['method'] = self.method
            ret_data['version'] = self.version
        elif action == 'reset':
            # reset server instance for fresh use
            self.start_time = datetime.now()
            self.end_time = None
            self.logger.info(
                f"Start time: {self.start_time.isoformat(' ', timespec='seconds')}")
            mode = data['mode']
            axes = data['axes']
            offset = data['offset']
            limits = data['limits']
            # optional background level and intensity threshold; `None` if absent
            level_backgr = data.get('level_backgr')
            thresh_intens = data.get('thresh_intens')
            # maximum travel cost for normalizing travel costs; defaults to `1`
            # if absent or `None`
            travel_cost_max = data.get('travel_cost_max')
            if travel_cost_max is None:
                travel_cost_max = 1
            scenario_name = data.get('scenario_name')
            assert len(axes) == len(limits), "axes and limits differ in length"
            assert level_backgr is None or level_backgr >= 0, \
                "level_backgr must be non-negative"
            assert thresh_intens is None or thresh_intens >= 0, \
                "thresh_intens must be non-negative"
            assert travel_cost_max > 0, "travel_cost_max must be positive"
            self.reset(mode, axes, offset, limits, level_backgr, thresh_intens,
                       travel_cost_max, scenario_name)
        elif action == 'result':
            # process new results, if necessary in a "busy thread"
            locs = data['locs']
            counts = data['counts']
            matrices_ellipses = data['matrices_ellipses']
            travel_time = data['travel_time']
            counting_time = data['counting_time']
            travel_cost_grid = data.get('travel_cost_grid')
            travel_cost_values = data.get('travel_cost_values')
            assert len(locs) == len(counts) == len(matrices_ellipses), \
                "locs, counts and matrices_ellipses differ in length"
            args = (locs, counts, matrices_ellipses, travel_time, counting_time,
                    travel_cost_grid, travel_cost_values)
            if not self.method_initialized or self.param_optim_stopped:
                # handled synchronously within this request
                self.result(*args)
                self.save_state(id='1_interm')
            else:
                ResultThread(self, *args).start()
        elif action == 'next_loc':
            # compute next location based on current results or initialize
            # method in a "busy thread" if not done yet
            if self.method_initialized:
                ret_data['loc'], ret_data['stop'] = self.next_loc()
            elif self.method_initializable:
                # no location is returned while initialization is running
                InitThread(self).start()
                ret_data['busy'] = True
            else:
                raise RuntimeError('method is not initializable')
        elif action == 'heuris_experi_param':
            # heuristically compute experiment parameters based on current results
            ret_data['level_backgr'], ret_data['thresh_intens'] = self.heuris_experi_param()
        elif action == 'problem_locs':
            # process problematic locations
            locs = data['locs']
            matrices_ellipses = data['matrices_ellipses']
            self.problem_locs(locs, matrices_ellipses)
        elif action == 'state_internal':
            if not self.method_initialized:
                raise RuntimeError("server is not initialized")
            # provide internal state of server
            ret_data.update(self.state_internal(data))
        elif action == 'stop':
            # finalize: log final information, save final state, stop server
            self.end_time = datetime.now()
            self.logger.info(f"End time: {self.end_time.isoformat(' ', timespec='seconds')}")
            if self.start_time is not None:
                # 'stop' may arrive without a preceding 'reset'; do not let
                # the log line prevent saving and stopping
                self.logger.info(
                    f"Total hours exceeded: {(self.end_time-self.start_time).total_seconds() / 3600.}")
            self.save_state(id='2_final')
            self.stop()
        else:
            raise ValueError('unknown action')
        return ret_data

    @property
    @abstractmethod
    def method_initializable(self):
        """Indicate if the method can be initialized.

        Returns
        -------
        method_initializable : bool
        """
        raise NotImplementedError

    @abstractmethod
    def _init_method(self):
        """Initialize the method.

        Called by `InitThread` in a background thread, after which the
        state is saved with id ``'0_init'``.
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def method_initialized(self):
        """Indicate if the method is initialized.

        Returns
        -------
        method_initialized : bool
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def param_optim_stopped(self):
        """Indicate if optimization of kernel hyperparameters is stopped.

        Returns
        -------
        param_optim_stopped : bool
        """
        raise NotImplementedError

    @abstractmethod
    def reset(self, mode, axes, offset, limits, level_backgr=None, thresh_intens=None,
              travel_cost_max=None, scenario_name=None):
        """Reset the server instance for fresh use.

        Called by `process` for the ``'reset'`` action. When called from
        there, ``len(axes) == len(limits)``, `level_backgr` and
        `thresh_intens` are ``None`` or non-negative, and `travel_cost_max`
        (used for normalizing travel costs) is positive.
        """
        raise NotImplementedError

    @abstractmethod
    def result(self, locs, counts, matrices_ellipses, travel_time, counting_time,
               travel_cost_grid=None, travel_cost_values=None):
        """Process new results.

        `locs`, `counts` and `matrices_ellipses` have equal lengths when
        called from `process`. May run synchronously or in `ResultThread`.
        """
        raise NotImplementedError

    @abstractmethod
    def next_loc(self):
        """Compute the next location.

        Returns
        -------
        loc
            Next location.
        stop
            Stop indicator sent to the client alongside `loc`.
        """
        raise NotImplementedError

    @abstractmethod
    def heuris_experi_param(self):
        """Heuristically compute experiment parameters from current results.

        Returns
        -------
        level_backgr
        thresh_intens
        """
        raise NotImplementedError

    @abstractmethod
    def problem_locs(self, locs, matrices_ellipses):
        """Process problematic locations."""
        raise NotImplementedError

    @abstractmethod
    def state_internal(self, data):
        """Return the internal state as a mapping merged into the response."""
        raise NotImplementedError

    @abstractmethod
    def stop(self):
        """Stop the server; called last when handling the ``'stop'`` action."""
        raise NotImplementedError

    @abstractmethod
    def save_state(self, id=None):
        """Save the current state of the server to a file. If `id` is not None,
        it is included in the filename.
        """
        raise NotImplementedError
class BusyThread(Thread):
    """Base class for threads setting a server instance to `busy`.

    The server's ``busy`` flag is raised synchronously in `start` (so a
    request handled right after the thread was launched already sees it)
    and is always cleared when `run` finishes, even if `_do_run` raises.
    Subclasses implement `_do_run`.
    """

    def __init__(self, server: 'TASServer'):
        super().__init__()
        self.server = server

    def start(self):
        """Mark the server busy, then start the thread."""
        # Set before the OS thread is scheduled; otherwise a request arriving
        # in between would still see ``busy == False``.
        self.server.busy = True
        try:
            super().start()
        except BaseException:
            # Launch failed: do not leave the server busy forever, unless a
            # previous start() of this object is still running.
            if not self.is_alive():
                self.server.busy = False
            raise

    def run(self):
        # also set here so that a direct run() call behaves the same
        self.server.busy = True
        try:
            self._do_run()
        finally:
            self.server.busy = False

    @abstractmethod
    def _do_run(self):
        """Do the actual work while the server is busy."""
        raise NotImplementedError
class InitThread(BusyThread):
    """Busy thread that initializes the server's method.

    The constructor is inherited unchanged from `BusyThread` and takes the
    server instance. The work consists of ``server._init_method()``,
    followed by saving the server state under the id ``'0_init'``.
    """

    def _do_run(self):
        srv = self.server
        srv._init_method()
        srv.save_state(id='0_init')
class ResultThread(BusyThread):
    """Busy thread that hands a batch of new results to the server.

    All positional arguments after `server` are kept verbatim in
    ``self.args`` and forwarded to ``server.result``. Afterwards, the server
    state is saved under the id ``'1_interm'``.
    """

    def __init__(self, server: TASServer, *args):
        super().__init__(server=server)
        # forwarded unchanged to ``server.result``
        self.args = args

    def _do_run(self):
        srv = self.server
        srv.result(*self.args)
        srv.save_state(id='1_interm')