Source code for ariane.lib.communication._communication

from ..utils.config import ConfigMixin
from ..utils.logging import LoggerMixin

from abc import abstractmethod
import json
from typing import Any
import zmq


class Server(ConfigMixin, LoggerMixin):
    """Base class for server communication.

    Binds a ZeroMQ ``REP`` socket on all interfaces at ``port`` and answers
    requests in :meth:`run_service`. Subclasses implement :meth:`process`
    to handle the individual actions.
    """

    def __init__(self, service_id: str, port: int, *, config: Any):
        """Initialise the mixins and bind the reply socket.

        Args:
            service_id: Identifier every incoming request must carry in its
                first frame.
            port: TCP port the reply socket is bound to.
            config: Passed through unchanged to ``ConfigMixin.__init__``.
        """
        ConfigMixin.__init__(self, config)
        LoggerMixin.__init__(self, name="server")

        self.service_id = service_id
        self.port = port

        # setup zmq server connection
        self._ctx = zmq.Context()
        self._rep = self._ctx.socket(zmq.REP)
        self._rep.bind(f'tcp://0.0.0.0:{self.port}')

    def run_service(self):
        """Runs the service in an infinite loop while receiving and sending
        messages according to our predefined format.

        Message format: [`service id`, empty, `action`, `data`]

        Every request is answered. On success the reply carries the dict
        returned by :meth:`process` with ``'success': True`` added; on any
        failure (malformed request, wrong service id, invalid JSON, an
        exception raised by :meth:`process`, or a result that cannot be
        serialised) it carries ``{'success': False, 'error': <message>}``.
        """
        while True:
            self.logger.info("Wait for requests...")
            msg = self._rep.recv_multipart()
            self.logger.info(f"Received: {msg[:3]}")

            try:
                # Explicit checks instead of ``assert``: assertions are
                # removed under ``python -O`` and carry no error text.
                if len(msg) != 4:
                    raise ValueError(
                        f"Malformed request: expected 4 frames, "
                        f"got {len(msg)}"
                    )
                if msg[0].decode() != self.service_id:
                    raise ValueError(
                        f"Unexpected service id {msg[0]!r}, "
                        f"expected {self.service_id!r}"
                    )

                action = msg[2].decode()
                data = json.loads(msg[3].decode())

                # call template method for further processing and
                # store the return data
                ret_data = self.process(action, data)
                ret_data['success'] = True

                # Serialise inside the ``try`` so that a non-JSON-able
                # result yields an error reply instead of killing the loop.
                payload = json.dumps(ret_data)
            except Exception as err:
                self.logger.error(str(err))
                payload = json.dumps({'success': False, 'error': str(err)})

            # A REP socket has to answer each request before it can receive
            # the next one, so a reply is built even for requests that are
            # too short to echo their service-id / action frames.
            service_frame = msg[0] if msg else self.service_id.encode()
            action_frame = msg[2] if len(msg) > 2 else b''

            # send return data back to the client
            reply = [service_frame, b'', action_frame, payload.encode()]
            self._rep.send_multipart(reply)
            self.logger.info(f"Sent: {reply[:3]}")

    @abstractmethod
    def process(self, action: str, data: dict):
        """Template method for processing incoming messages.

        Args:
            action: Decoded third frame of the request.
            data: JSON-decoded fourth frame of the request.

        Returns:
            A JSON-serialisable dict; :meth:`run_service` adds the
            ``'success'`` key to it before replying.
        """
        raise NotImplementedError
class Client(LoggerMixin):
    """Base class for client communication.

    It contains methods for communicating (sending, receiving) with a server
    according to the predefined message format. For the format, please see
    the server's documentation.
    """

    def __init__(self, service_id: str, port: int, ip_server: str):
        """Connect a ZeroMQ ``REQ`` socket to the server.

        Args:
            service_id: Identifier sent as the first frame of every request
                and expected back in every reply.
            port: TCP port of the server.
            ip_server: Host name or IP address of the server.
        """
        super().__init__(name="client")

        self.service_id = service_id
        self.port = port
        self.ip_server = ip_server

        # setup zmq client connection
        self._ctx = zmq.Context()
        self._req = self._ctx.socket(zmq.REQ)
        self._req.connect(f"tcp://{self.ip_server}:{self.port}")

    def send(self, action: str, data: dict):
        """Send one request to the server.

        Args:
            action: Name of the action; sent as the third frame.
            data: JSON-serialisable payload; sent as the fourth frame.
        """
        data = json.dumps(data)
        msg = [self.service_id.encode(), b'', action.encode(), data.encode()]
        self._req.send_multipart(msg)
        self.logger.info(f"Sent: {msg[:3]}")

    def receive(self):
        """Receive and validate one reply from the server.

        Returns:
            A tuple ``(action, data)`` with the decoded action frame and the
            JSON-decoded payload.

        Raises:
            AssertionError: If the reply does not have four frames or its
                service id differs from ``self.service_id``.
            SystemExit: If the payload reports ``'success'`` as falsy; the
                server's error text is logged first.
        """
        msg = self._req.recv_multipart()
        self.logger.info(f"Received: {msg[:3]}")

        # Raised explicitly (rather than via ``assert``) so the checks
        # survive ``python -O``; the exception type stays AssertionError.
        if len(msg) != 4:
            raise AssertionError(
                f"Malformed reply: expected 4 frames, got {len(msg)}"
            )
        if msg[0].decode() != self.service_id:
            raise AssertionError(
                f"Unexpected service id {msg[0]!r}, "
                f"expected {self.service_id!r}"
            )

        action = msg[2].decode()
        data = json.loads(msg[3].decode())

        if not data['success']:
            self.logger.error(f"Error at server: {data['error']}")
            # ``exit()`` is an interactive helper injected by ``site`` and
            # would terminate with status 0; raise SystemExit directly with
            # a non-zero status so the failure is visible to the caller.
            raise SystemExit(1)

        return action, data