import asyncio
import functools
import inspect
import logging
from abc import ABC, abstractmethod
from collections.abc import Awaitable
from concurrent.futures import ThreadPoolExecutor
import zmq
import zmq.asyncio
from qulab.exceptions import (QuLabRPCError, QuLabRPCServerError,
QuLabRPCTimeout)
from qulab.serialize import pack, unpack
from qulab.utils import acceptArg, randomID
log = logging.getLogger(__name__)  # pylint: disable=invalid-name

# Message-type tags. The first byte of every frame tells RPCMixin.handle
# which on_* handler receives the remaining bytes.
RPC_REQUEST = bytes([0x01])
RPC_RESPONSE = bytes([0x02])
RPC_PING = bytes([0x03])
RPC_PONG = bytes([0x04])
RPC_CANCEL = bytes([0x05])
RPC_SHUTDOWN = bytes([0x06])
class RPCMixin(ABC):
    """
    Transport-agnostic RPC plumbing shared by clients and servers.

    Subclasses supply an event loop (``loop``) and a way to send raw bytes
    (``sendto``), and must pass every received message to :meth:`handle`.
    The first byte of a message is a type tag (``RPC_*``) that selects one
    of the ``on_*`` handlers.
    """
    __pending = None
    __tasks = None

    @property
    def pending(self):
        """
        Outstanding waits: ``{msgID: (future, timer_handle)}``, created lazily.
        """
        if self.__pending is None:
            self.__pending = {}
        return self.__pending

    @property
    def tasks(self):
        """
        Local tasks keyed by the msgID that started them, created lazily.
        """
        if self.__tasks is None:
            self.__tasks = {}
        return self.__tasks

    def start(self):
        """
        Start hook; the mixin itself does nothing.

        Defined so cooperative ``super().start()`` calls in subclasses end
        here instead of raising AttributeError.
        """

    def stop(self):
        """
        Stop hook called by :meth:`close`; the mixin itself does nothing.
        """

    def close(self):
        """
        Stop the endpoint, cancel all local tasks and abandon all pending waits.
        """
        self.stop()
        for task in list(self.tasks.values()):
            task.cancel()
        self.tasks.clear()
        for fut, timer in list(self.pending.values()):
            timer.cancel()
            fut.cancel()
        # With their timers cancelled these entries would never be removed.
        self.pending.clear()

    def createTask(self, msgID, coro, timeout=0):
        """
        Create a new task for msgID.

        If ``timeout > 0`` the coroutine is wrapped in ``asyncio.wait_for``.
        The entry is removed from :attr:`tasks` when the task finishes.
        Returns the task.
        """
        if timeout > 0:
            coro = asyncio.wait_for(coro, timeout)
        task = asyncio.ensure_future(coro, loop=self.loop)
        self.tasks[msgID] = task

        def clean(fut, msgID=msgID):
            # Only forget the entry if it is still this task; a newer task
            # registered under the same msgID must stay cancellable.
            if self.tasks.get(msgID) is fut:
                del self.tasks[msgID]

        task.add_done_callback(clean)
        return task

    def cancelTask(self, msgID):
        """
        Cancel the task for msgID.
        """
        task = self.tasks.pop(msgID, None)
        if task is not None:
            task.cancel()

    def createPending(self, addr, msgID, timeout=1, cancelRemote=True):
        """
        Create a future for request, wait response before timeout.

        After ``timeout`` seconds :meth:`cancelPending` fails the future
        with QuLabRPCTimeout (and, if ``cancelRemote``, asks ``addr`` to
        cancel its task).
        """
        fut = self.loop.create_future()

        def expire():
            entry = self.pending.get(msgID)
            if entry is not None and entry[0] is fut:
                self.cancelPending(addr, msgID, cancelRemote)
            elif not fut.done():
                # A newer wait reused this msgID (e.g. two pings to the same
                # address); time out only this future, leave the newer entry.
                fut.set_exception(QuLabRPCTimeout('Time out.'))

        self.pending[msgID] = (fut, self.loop.call_later(timeout, expire))
        return fut

    def cancelPending(self, addr, msgID, cancelRemote):
        """
        Give up when request timeout and try to cancel remote task.
        """
        entry = self.pending.pop(msgID, None)
        if entry is None:
            return
        fut, timer = entry
        timer.cancel()  # harmless if this is the timer that just fired
        if cancelRemote:
            self.cancelRemoteTask(addr, msgID)
        # The waiter may already be cancelled; set_exception on a done
        # future would raise InvalidStateError.
        if not fut.done():
            fut.set_exception(QuLabRPCTimeout('Time out.'))

    def cancelRemoteTask(self, addr, msgID):
        """
        Try to cancel remote task.
        """
        asyncio.ensure_future(self.sendto(RPC_CANCEL + msgID, addr),
                              loop=self.loop)

    @property
    @abstractmethod
    def loop(self):
        """
        Event loop.
        """

    @abstractmethod
    async def sendto(self, data, address):
        """
        Send message to address.
        """

    # Maps the one-byte message type to the name of its handler method.
    __rpc_handlers = {
        RPC_PING: 'on_ping',
        RPC_PONG: 'on_pong',
        RPC_REQUEST: 'on_request',
        RPC_RESPONSE: 'on_response',
        RPC_CANCEL: 'on_cancel',
        RPC_SHUTDOWN: 'on_shutdown',
    }

    def handle(self, source, data):
        """
        Handle received data.
        Should be called whenever received data from outside.

        Messages with an unknown type byte are ignored.
        """
        msg_type, data = data[:1], data[1:]
        log.debug('received request %s from %s', msg_type, source)
        handler = self.__rpc_handlers.get(msg_type)
        if handler is not None:
            getattr(self, handler)(source, data)

    async def ping(self, addr, timeout=1):
        """
        Return True if ``addr`` answers with a pong within ``timeout`` seconds,
        otherwise False. The wait is keyed by ``addr`` itself.
        """
        # Register the waiter before sending so a fast pong cannot be missed.
        fut = self.createPending(addr, addr, timeout, False)
        try:
            await self.sendto(RPC_PING, addr)
            return await fut
        except QuLabRPCTimeout:
            return False
        finally:
            if not fut.done():
                # sendto() failed or was cancelled: withdraw the waiter so its
                # timer does not later fail a future nobody is awaiting.
                entry = self.pending.get(addr)
                if entry is not None and entry[0] is fut:
                    del self.pending[addr]
                    entry[1].cancel()
                fut.cancel()

    async def pong(self, addr):
        """
        Answer a ping from ``addr``.
        """
        await self.sendto(RPC_PONG, addr)

    async def request(self, address, msgID, msg):
        """
        Send a packed request ``msg`` tagged with ``msgID``.
        """
        log.debug(f'send request {address}, {msgID.hex()}, {msg}')
        await self.sendto(RPC_REQUEST + msgID + msg, address)

    async def response(self, address, msgID, msg):
        """
        Send a packed response ``msg`` for request ``msgID``.
        """
        log.debug(f'send response {address}, {msgID.hex()}, {msg}')
        await self.sendto(RPC_RESPONSE + msgID + msg, address)

    async def shutdown(self, address):
        """
        Ask ``address`` to shut down (see :meth:`on_shutdown`).
        """
        await self.sendto(RPC_SHUTDOWN, address)

    def on_request(self, source, data):
        """
        Handle request.
        Overwrite this method on server.
        """

    def on_response(self, source, data):
        """
        Handle response.
        Overwrite this method on client.
        """

    def on_ping(self, source, data):
        """
        Reply to a ping by scheduling :meth:`pong`.
        """
        log.debug("received ping from %s", source)
        asyncio.ensure_future(self.pong(source), loop=self.loop)

    def on_pong(self, source, data):
        """
        Resolve the :meth:`ping` waiting on ``source``, if any.
        """
        log.debug("received pong from %s", source)
        entry = self.pending.pop(source, None)
        if entry is None:
            return
        fut, timer = entry
        timer.cancel()
        # The pinger may have been cancelled meanwhile; resolving a done
        # future raises InvalidStateError, which would escape handle().
        if not fut.done():
            fut.set_result(True)

    def on_cancel(self, source, data):
        """
        Cancel the local task whose 20-byte msgID starts ``data``.
        """
        msgID = data[:20]
        self.cancelTask(msgID)

    def on_shutdown(self, source, data):
        """
        Raise SystemExit(0) if :meth:`is_admin` allows ``source``.
        """
        if self.is_admin(source, data):
            raise SystemExit(0)

    def is_admin(self, source, data):
        """
        Return True if ``source`` may shut this endpoint down.

        NOTE(review): the default accepts every peer, so anyone able to reach
        the socket can trigger on_shutdown; override to restrict it.
        """
        return True
class RPCClientMixin(RPCMixin):
    """
    Client side of the RPC protocol: sends requests and resolves their futures.
    """
    _client_defualt_timeout = 10

    def set_timeout(self, timeout=10):
        """
        Set the default number of seconds to wait for a remote call.
        """
        self._client_defualt_timeout = timeout

    def remoteCall(self, addr, methodNane, args=(), kw=None):
        """
        Call ``methodNane(*args, **kw)`` on the server at ``addr``.

        Returns a future that :meth:`on_response` resolves with the remote
        result (or fails with the remote exception). If no reply arrives
        within ``kw['timeout']`` seconds (default: the client timeout), the
        future fails with QuLabRPCTimeout. ``kw`` is packed into the request
        unchanged, so a ``timeout`` key is also sent to the server.
        """
        if kw is None:
            kw = {}
        timeout = kw.get('timeout', self._client_defualt_timeout)
        msg = pack((methodNane, args, kw))
        msgID = randomID()
        asyncio.ensure_future(self.request(addr, msgID, msg), loop=self.loop)
        return self.createPending(addr, msgID, timeout)

    def on_response(self, source, data):
        """
        Client side.

        ``data`` is a 20-byte msgID followed by the packed result. The
        matching pending future receives the result, or an exception if the
        result is an Exception or cannot be unpacked.
        """
        msgID, msg = data[:20], data[20:]
        entry = self.pending.pop(msgID, None)
        if entry is None:
            # No matching wait (unknown msgID, already timed out, or closed).
            return
        fut, timer = entry
        timer.cancel()
        if fut.done():
            # The caller gave up (e.g. its task was cancelled); resolving the
            # future now would raise InvalidStateError out of handle().
            return
        try:
            result = unpack(msg)
        except Exception as e:  # pylint: disable=broad-except
            # Fail this call instead of letting the error end the receive loop.
            result = e
        if isinstance(result, Exception):
            fut.set_exception(result)
        else:
            fut.set_result(result)
class RPCServerMixin(RPCMixin):
    """
    Server side of the RPC protocol: runs requested methods and replies.
    """

    def _unpack_request(self, msg):
        """
        Decode a request payload into ``(method, args, kw)``.

        Raises QuLabRPCError if the payload cannot be unpacked into exactly
        three items.
        """
        try:
            method, args, kw = unpack(msg)
        except Exception as e:
            raise QuLabRPCError("Could not read packet: %r" % msg) from e
        return method, args, kw

    @property
    def executor(self):
        """
        Executor for synchronous handlers; None selects the loop's default.
        """
        return None

    @abstractmethod
    def getRequestHandler(self, methodNane, source, msgID):
        """
        Get suitable handler for request.
        You should implement this method yourself.
        """

    def on_request(self, source, data):
        """
        Received a request from source.

        ``data`` is a 20-byte msgID followed by the packed request. The call
        runs as a task registered under msgID, so RPC_CANCEL can stop it.
        """
        msgID, msg = data[:20], data[20:]
        try:
            method, args, kw = self._unpack_request(msg)
            self.createTask(msgID,
                            self.handle_request(source, msgID, method, *args,
                                                **kw),
                            timeout=kw.get('timeout', 0))
        except Exception as e:
            # response() is a coroutine: it must be scheduled, otherwise the
            # error reply is never sent.
            asyncio.ensure_future(
                self.response(source, msgID,
                              pack(QuLabRPCServerError.make(e))),
                loop=self.loop)

    async def handle_request(self, source, msgID, method, *args, **kw):
        """
        Handle a request from source.

        Coroutine functions are awaited directly. Other callables run in
        :attr:`executor`, and an awaitable they return is awaited too. A
        ``timeout`` keyword is dropped unless ``acceptArg`` reports that the
        handler takes it. QuLabRPCError is sent back as-is; any other
        exception is wrapped with ``QuLabRPCServerError.make``.
        """
        try:
            func = self.getRequestHandler(method, source=source, msgID=msgID)
            if 'timeout' in kw and not acceptArg(func, 'timeout'):
                del kw['timeout']
            if inspect.iscoroutinefunction(func):
                result = await func(*args, **kw)
            else:
                result = await self.loop.run_in_executor(
                    self.executor, functools.partial(func, *args, **kw))
                if isinstance(result, Awaitable):
                    result = await result
        except QuLabRPCError as e:
            result = e
        except Exception as e:
            result = QuLabRPCServerError.make(e)
        msg = pack(result)
        await self.response(source, msgID, msg)
class ZMQServer(RPCServerMixin):
    """
    RPC server that exposes the attributes of an object over a ZeroMQ ROUTER socket.

    Call :meth:`set_module` to choose what is served, then :meth:`start`.
    The socket binds to a random TCP port, available from :attr:`port`
    once the receive loop is running.
    """

    def __init__(self, loop=None):
        self.zmq_main_task = None
        self.zmq_ctx = None
        self.zmq_socket = None
        self._port = 0
        self._loop = loop or asyncio.get_event_loop()
        self._module = None
        self._executor = None

    @property
    def executor(self):
        """
        Executor for synchronous handlers (None: the loop's default executor).
        """
        return self._executor

    def set_module(self, mod):
        """
        Serve the attributes of ``mod`` as remote methods.
        """
        self._module = mod

    async def sendto(self, data, address):
        """
        Send ``data`` to the peer whose ROUTER identity is ``address``.
        """
        # On a zmq.asyncio socket send_multipart() returns a future. Await it,
        # as ZMQClient.sendto does, so send errors are raised here instead of
        # being lost in an unretrieved future.
        await self.zmq_socket.send_multipart([address, data])

    def getRequestHandler(self, methodNane, **kw):
        """
        Resolve a dotted name such as ``"dev.read"`` against the served object.

        Raises AttributeError for unknown names and for any segment that
        starts with ``"__"``. The name comes from the network, and dunder
        traversal (``__class__``, ``__globals__``, ...) would expose far more
        than the served object's API.
        """
        target = self._module
        for part in methodNane.split('.'):
            if part.startswith('__'):
                raise AttributeError(f"access to {part!r} is not allowed")
            target = getattr(target, part)
        return target

    @property
    def loop(self):
        """
        Event loop the server runs on.
        """
        return self._loop

    @property
    def port(self):
        """
        TCP port bound by :meth:`run` (0 until the socket is bound).
        """
        return self._port

    def set_socket(self, sock):
        """
        Use ``sock`` for outgoing replies.
        """
        self.zmq_socket = sock

    def start(self):
        """
        Start the receive loop as a background task.
        """
        super().start()
        self.zmq_ctx = zmq.asyncio.Context.instance()
        self.zmq_main_task = asyncio.ensure_future(self.run(), loop=self.loop)

    def stop(self):
        """
        Shut down the executor (if any) and cancel the receive loop.
        """
        if self._executor is not None:
            self._executor.shutdown()
        if self.zmq_main_task is not None and not self.zmq_main_task.done():
            self.zmq_main_task.cancel()
        super().stop()

    async def run(self):
        """
        Bind a ROUTER socket to a random port and dispatch incoming messages.
        """
        with self.zmq_ctx.socket(zmq.ROUTER, io_loop=self._loop) as sock:
            sock.setsockopt(zmq.LINGER, 0)
            self._port = sock.bind_to_random_port('tcp://*')
            self.set_socket(sock)
            while True:
                frames = await sock.recv_multipart()
                if len(frames) != 2:
                    # Expected [identity, data]; anything else would make the
                    # unpacking raise and end this loop.
                    log.warning('dropped malformed message with %d frames',
                                len(frames))
                    continue
                addr, data = frames
                log.debug('received data from %r', addr.hex())
                self.handle(addr, data)
class ZMQRPCCallable:
    """
    Proxy for a (possibly dotted) remote method name bound to an owner.

    Attribute access extends the dotted name; calling the proxy returns
    ``owner.performMethod(name, args, kw)``.
    """

    def __init__(self, methodNane, owner):
        self.methodNane = methodNane
        self.owner = owner

    def __call__(self, *args, **kw):
        return self.owner.performMethod(self.methodNane, args, kw)

    def __getattr__(self, name):
        # Never proxy dunder names. copy, pickle, hasattr probes and similar
        # look up optional hooks such as __deepcopy__ or __setstate__ and
        # must get AttributeError, not a remote call. Returning early also
        # prevents infinite recursion on half-built instances (e.g. during
        # copy) that do not have ``methodNane`` yet.
        if name.startswith('__') and name.endswith('__'):
            raise AttributeError(name)
        return ZMQRPCCallable(f"{self.methodNane}.{name}", self.owner)
class ZMQClient(RPCClientMixin):
    """
    RPC client talking to one server over a ZeroMQ DEALER socket.

    Unknown attributes become remote-method proxies, so
    ``client.dev.read(1)`` returns a future for the remote call
    ``dev.read(1)``.
    """

    def __init__(self, addr, timeout=10, loop=None):
        self._loop = loop or asyncio.get_event_loop()
        self.set_timeout(timeout)
        self.addr = addr
        self._ctx = zmq.asyncio.Context()
        self.zmq_socket = self._ctx.socket(zmq.DEALER, io_loop=self._loop)
        self.zmq_socket.setsockopt(zmq.LINGER, 0)
        self.zmq_socket.connect(self.addr)
        self.zmq_main_task = asyncio.ensure_future(self.run(), loop=self.loop)

    def __del__(self):
        # close() calls stop(), which cancels the receive loop, then drops
        # local tasks and pending calls. Attributes are read from __dict__
        # because __init__ may have failed part-way, and a normal lookup of a
        # missing attribute would fall through to __getattr__ and build a
        # remote-call proxy instead of raising.
        self.close()
        sock = self.__dict__.get('zmq_socket')
        if sock is not None:
            sock.close()

    def stop(self):
        """
        Cancel the background receive loop started in ``__init__``.
        """
        task = self.__dict__.get('zmq_main_task')
        if task is not None and not task.done():
            task.cancel()

    @property
    def loop(self):
        """
        Event loop the client runs on.
        """
        return self._loop

    async def ping(self, timeout=1):
        """
        Return True if the server answers within ``timeout`` seconds.
        """
        return await super().ping(self.addr, timeout=timeout)

    async def sendto(self, data, addr):
        """
        Send ``data`` to the connected server (``addr`` is ignored).
        """
        await self.zmq_socket.send_multipart([data])

    async def run(self):
        """
        Receive server messages forever and dispatch them via :meth:`handle`.
        """
        while True:
            frames = await self.zmq_socket.recv_multipart()
            if len(frames) != 1:
                # The server replies with a single data frame; anything else
                # would make unpacking raise and end this loop.
                log.warning('dropped malformed message with %d frames',
                            len(frames))
                continue
            self.handle(self.addr, frames[0])

    def performMethod(self, methodNane, args, kw):
        """
        Entry point used by ZMQRPCCallable: call ``methodNane`` on the server.
        """
        return self.remoteCall(self.addr, methodNane, args, kw)

    def __getattr__(self, name):
        # Dunder lookups (copy/pickle/introspection hooks) must not become
        # remote calls.
        if name.startswith('__') and name.endswith('__'):
            raise AttributeError(name)
        return ZMQRPCCallable(name, self)