Source code for ariane.lib.communication._communication
from ..utils.config import ConfigMixin
from ..utils.logging import LoggerMixin
from abc import abstractmethod
import json
from typing import Any
import zmq
[docs]
class Server(ConfigMixin, LoggerMixin):
    """Base class for server communication.

    Binds a ZeroMQ ``REP`` socket on ``tcp://0.0.0.0:<port>`` and answers
    requests in a blocking loop (see :meth:`run_service`). Subclasses
    implement :meth:`process` to handle individual actions.

    Message format (requests and replies), as multipart frames:
    ``[service_id, b'', action, json_payload]``.
    """

    def __init__(self, service_id: str, port: int, *, config: Any):
        """Initialize the mixins and bind the reply socket.

        Args:
            service_id: Identifier every incoming request must carry in its
                first frame.
            port: TCP port to bind on all interfaces (``0.0.0.0``).
            config: Configuration object handed to ``ConfigMixin``.
        """
        ConfigMixin.__init__(self, config)
        LoggerMixin.__init__(self, name="server")
        self.service_id = service_id
        self.port = port
        # setup zmq server connection
        self._ctx = zmq.Context()
        self._rep = self._ctx.socket(zmq.REP)
        self._rep.bind(f'tcp://0.0.0.0:{self.port}')

    def run_service(self):
        """Serve requests forever according to our predefined format.

        Message format: [`service id`, empty, `action`, `data`]

        Every received request gets exactly one reply, including malformed
        ones: a ZeroMQ REP socket cannot receive the next request until the
        current one has been answered. Failures while handling a request are
        reported to the client as ``{'success': False, 'error': ...}``. Only
        exceptions raised by the socket calls themselves escape this loop.
        """
        while True:
            self.logger.info("Wait for requests...")
            msg = self._rep.recv_multipart()
            self.logger.info(f"Received: {msg[:3]}")
            payload = self._handle(msg)
            # Echo the service id and action frames. Missing frames (malformed
            # request) are replaced by empty ones instead of raising IndexError.
            reply = [self._frame(msg, 0), b'', self._frame(msg, 2), payload]
            self._rep.send_multipart(reply)
            self.logger.info(f"Sent: {reply[:3]}")

    def _handle(self, msg) -> bytes:
        """Validate one request, dispatch it to :meth:`process` and return
        the UTF-8 encoded JSON reply payload (never raises)."""
        try:
            # Explicit checks instead of `assert`: asserts vanish under -O and
            # produce an empty error message for the client.
            if len(msg) != 4:
                raise ValueError(
                    f"Expected 4 message frames, got {len(msg)}")
            service_id = msg[0].decode()
            if service_id != self.service_id:
                raise ValueError(
                    f"Unknown service id {service_id!r}, "
                    f"expected {self.service_id!r}")
            action = msg[2].decode()
            data = json.loads(msg[3].decode())
            # call template method for further processing and
            # store the return data
            ret_data = self.process(action, data)
            ret_data['success'] = True
            # Serialize inside the try so a non-JSON-serializable result is
            # turned into an error reply instead of crashing the loop.
            return json.dumps(ret_data).encode()
        except Exception as err:
            self.logger.error(str(err))
            return json.dumps({'success': False, 'error': str(err)}).encode()

    @staticmethod
    def _frame(msg, index: int) -> bytes:
        """Return frame ``index`` of ``msg``, or ``b''`` if it is missing."""
        return msg[index] if index < len(msg) else b''

    @abstractmethod
    def process(self, action: str, data: dict):
        """Template method for processing incoming messages.

        Args:
            action: Decoded action name from the third frame.
            data: JSON-decoded payload from the fourth frame.

        Returns:
            A dict to send back to the client. ``run_service`` adds
            ``'success': True`` to it before serializing it to JSON.

        NOTE(review): ``Server`` does not derive from ``abc.ABC`` here, so
        ``@abstractmethod`` is only enforced if one of the mixins supplies
        ``ABCMeta`` (not visible in this file). Otherwise the base
        implementation raises ``NotImplementedError``, which ``run_service``
        reports to the client as an error reply.
        """
        raise NotImplementedError
[docs]
class Client(LoggerMixin):
    """Base class for client communication.

    It contains methods for communicating (sending, receiving) with a server
    according to the predefined message format
    ``[service_id, b'', action, json_payload]``.
    For the format, please see the server's documentation.

    Communication uses a ZeroMQ ``REQ`` socket, so each :meth:`send` must be
    followed by exactly one :meth:`receive`.
    """

    def __init__(self, service_id: str, port: int, ip_server: str):
        """Connect a request socket to ``tcp://<ip_server>:<port>``.

        Args:
            service_id: Identifier placed in the first frame of every request
                and expected in the first frame of every reply.
            port: TCP port of the server.
            ip_server: Host name or IP address of the server.
        """
        super().__init__(name="client")
        self.service_id = service_id
        self.port = port
        self.ip_server = ip_server
        # setup zmq client connection
        self._ctx = zmq.Context()
        self._req = self._ctx.socket(zmq.REQ)
        self._req.connect(f"tcp://{self.ip_server}:{self.port}")

    def send(self, action: str, data: dict):
        """Send a request for ``action`` with ``data`` serialized as JSON."""
        payload = json.dumps(data)
        msg = [self.service_id.encode(), b'', action.encode(),
               payload.encode()]
        self._req.send_multipart(msg)
        self.logger.info(f"Sent: {msg[:3]}")

    def receive(self):
        """Block until the server's reply arrives and return it.

        Returns:
            ``(action, data)``: the echoed action name and the JSON-decoded
            reply payload.

        Raises:
            ValueError: if the reply does not have 4 frames or carries a
                different service id.
            SystemExit: (exit status 1) if the server reported a failure;
                the server's error message is logged first.
        """
        msg = self._req.recv_multipart()
        self.logger.info(f"Received: {msg[:3]}")
        # Explicit checks instead of `assert`, which is stripped under -O.
        if len(msg) != 4:
            raise ValueError(f"Expected 4 message frames, got {len(msg)}")
        service_id = msg[0].decode()
        if service_id != self.service_id:
            raise ValueError(
                f"Unknown service id {service_id!r}, "
                f"expected {self.service_id!r}")
        action = msg[2].decode()
        data = json.loads(msg[3].decode())
        # A reply without 'success' is treated as a failure, not a KeyError.
        if not data.get('success', False):
            self.logger.error(f"Error at server: {data.get('error')}")
            # The bare `exit()` builtin is injected by the `site` module and
            # is not guaranteed to exist; it also exited with status 0.
            raise SystemExit(1)
        return action, data