# Source code for qulab.rpc: an asyncio + ZeroMQ request/response RPC layer.

import asyncio
import functools
import inspect
import logging
from abc import ABC, abstractmethod
from collections.abc import Awaitable
from concurrent.futures import ThreadPoolExecutor

import zmq
import zmq.asyncio

from qulab.exceptions import (QuLabRPCError, QuLabRPCServerError,
                              QuLabRPCTimeout)
from qulab.serialize import pack, unpack
from qulab.utils import acceptArg, randomID

log = logging.getLogger(__name__)  # pylint: disable=invalid-name

# Message-type tags: every RPC frame begins with exactly one of these
# single-byte prefixes, which selects the handler for the rest of the frame.
RPC_REQUEST = bytes([0x01])   # call a remote method (handled by on_request)
RPC_RESPONSE = bytes([0x02])  # result of a request (handled by on_response)
RPC_PING = bytes([0x03])      # liveness probe; answered with RPC_PONG
RPC_PONG = bytes([0x04])      # reply to RPC_PING
RPC_CANCEL = bytes([0x05])    # cancel the task running for a msgID
RPC_SHUTDOWN = bytes([0x06])  # ask the peer to exit (if is_admin allows)


[docs]class RPCMixin(ABC): __pending = None __tasks = None @property def pending(self): if self.__pending is None: self.__pending = {} return self.__pending @property def tasks(self): if self.__tasks is None: self.__tasks = {} return self.__tasks
[docs] def start(self): pass
[docs] def stop(self): pass
[docs] def close(self): self.stop() for task in self.tasks.values(): try: task.cancel() finally: pass for fut, timeout in self.pending.values(): try: fut.cancel() timeout.cancel() finally: pass
[docs] def createTask(self, msgID, coro, timeout=0): """ Create a new task for msgID. """ if timeout > 0: coro = asyncio.wait_for(coro, timeout) task = asyncio.ensure_future(coro, loop=self.loop) self.tasks[msgID] = task def clean(fut, msgID=msgID): if msgID in self.tasks: del self.tasks[msgID] task.add_done_callback(clean)
[docs] def cancelTask(self, msgID): """ Cancel the task for msgID. """ if msgID in self.tasks: self.tasks[msgID].cancel() del self.tasks[msgID]
[docs] def createPending(self, addr, msgID, timeout=1, cancelRemote=True): """ Create a future for request, wait response before timeout. """ fut = self.loop.create_future() self.pending[msgID] = (fut, self.loop.call_later(timeout, self.cancelPending, addr, msgID, cancelRemote)) return fut
[docs] def cancelPending(self, addr, msgID, cancelRemote): """ Give up when request timeout and try to cancel remote task. """ fut, timeout = self.pending[msgID] if cancelRemote: self.cancelRemoteTask(addr, msgID) fut.set_exception(QuLabRPCTimeout('Time out.')) del self.pending[msgID]
[docs] def cancelRemoteTask(self, addr, msgID): """ Try to cancel remote task. """ asyncio.ensure_future(self.sendto(RPC_CANCEL + msgID, addr), loop=self.loop)
@property @abstractmethod def loop(self): """ Event loop. """
[docs] @abstractmethod async def sendto(self, data, address): """ Send message to address. """
__rpc_handlers = { RPC_PING: 'on_ping', RPC_PONG: 'on_pong', RPC_REQUEST: 'on_request', RPC_RESPONSE: 'on_response', RPC_CANCEL: 'on_cancel', RPC_SHUTDOWN: 'on_shutdown', }
[docs] def handle(self, source, data): """ Handle received data. Should be called whenever received data from outside. """ msg_type, data = data[:1], data[1:] log.debug(f'received request {msg_type} from {source}') handler = self.__rpc_handlers.get(msg_type, None) if handler is not None: getattr(self, handler)(source, data)
[docs] async def ping(self, addr, timeout=1): await self.sendto(RPC_PING, addr) fut = self.createPending(addr, addr, timeout, False) try: return await fut except QuLabRPCTimeout: return False
[docs] async def pong(self, addr): await self.sendto(RPC_PONG, addr)
[docs] async def request(self, address, msgID, msg): log.debug(f'send request {address}, {msgID.hex()}, {msg}') await self.sendto(RPC_REQUEST + msgID + msg, address)
[docs] async def response(self, address, msgID, msg): log.debug(f'send response {address}, {msgID.hex()}, {msg}') await self.sendto(RPC_RESPONSE + msgID + msg, address)
[docs] async def shutdown(self, address): await self.sendto(RPC_SHUTDOWN, address)
[docs] def on_request(self, source, data): """ Handle request. Overwrite this method on server. """
[docs] def on_response(self, source, data): """ Handle response. Overwrite this method on client. """
[docs] def on_ping(self, source, data): log.debug(f"received ping from {source}") asyncio.ensure_future(self.pong(source), loop=self.loop)
[docs] def on_pong(self, source, data): log.debug(f"received pong from {source}") if source in self.pending: fut, timeout = self.pending[source] timeout.cancel() fut.set_result(True) del self.pending[source]
[docs] def on_cancel(self, source, data): msgID = data[:20] self.cancelTask(msgID)
[docs] def on_shutdown(self, source, data): if self.is_admin(source, data): raise SystemExit(0)
[docs] def is_admin(self, source, data): return True
class RPCClientMixin(RPCMixin):
    """Client half of the RPC protocol: issues requests, resolves responses."""

    # Default per-call timeout in seconds. The (misspelled) attribute name is
    # kept for backward compatibility with code that reads it directly.
    _client_defualt_timeout = 10

    def set_timeout(self, timeout=10):
        """Set the default timeout used by :meth:`remoteCall`."""
        self._client_defualt_timeout = timeout

    def remoteCall(self, addr, methodNane, args=(), kw=None):
        """
        Call ``methodNane(*args, **kw)`` on the peer at ``addr``.

        Returns a future that resolves with the remote result, or fails with
        the remote exception or ``QuLabRPCTimeout``. A ``timeout`` entry in
        ``kw`` overrides the default timeout and is also sent to the server.
        """
        if kw is None:
            kw = {}
        timeout = kw.get('timeout', self._client_defualt_timeout)
        msg = pack((methodNane, args, kw))
        msgID = randomID()
        asyncio.ensure_future(self.request(addr, msgID, msg), loop=self.loop)
        return self.createPending(addr, msgID, timeout)

    def on_response(self, source, data):
        """
        Client side: resolve the pending future matching the response.

        ``data`` is a 20-byte msgID followed by the packed result; a packed
        ``Exception`` is raised to the caller instead of returned.
        """
        msgID, msg = data[:20], data[20:]
        entry = self.pending.pop(msgID, None)
        if entry is None:
            # Unknown msgID, or the request already timed out.
            return
        fut, timeout = entry
        timeout.cancel()
        if fut.done():
            # The caller already gave up (e.g. cancelled the future);
            # resolving it again would raise InvalidStateError.
            return
        try:
            result = unpack(msg)
        except Exception as e:  # pylint: disable=broad-except
            # Deliver decoding failures to the caller rather than letting
            # them escape into the receive loop.
            fut.set_exception(e)
            return
        if isinstance(result, Exception):
            fut.set_exception(result)
        else:
            fut.set_result(result)
class RPCServerMixin(RPCMixin):
    """Server half of the RPC protocol: decodes requests, sends responses."""

    def _unpack_request(self, msg):
        """Decode ``msg`` into ``(method, args, kw)`` or raise QuLabRPCError."""
        try:
            method, args, kw = unpack(msg)
        except Exception as e:
            raise QuLabRPCError("Could not read packet: %r" % msg) from e
        return method, args, kw

    @property
    def executor(self):
        """Executor for synchronous handlers; None selects the loop's default."""
        return None

    @abstractmethod
    def getRequestHandler(self, methodNane, source, msgID):
        """
        Get suitable handler for request.

        You should implement this method yourself.
        """

    def on_request(self, source, data):
        """
        Received a request from source.

        ``data`` is a 20-byte msgID followed by the packed
        ``(method, args, kw)``. Decoding errors are answered immediately with
        a packed ``QuLabRPCServerError``.
        """
        msgID, msg = data[:20], data[20:]
        try:
            method, args, kw = self._unpack_request(msg)
            self.createTask(msgID,
                            self.handle_request(source, msgID, method, *args,
                                                **kw),
                            timeout=kw.get('timeout', 0))
        except Exception as e:  # pylint: disable=broad-except
            # ``response`` is a coroutine function: it must be scheduled,
            # otherwise the error reply is silently never sent.
            asyncio.ensure_future(
                self.response(source, msgID, pack(QuLabRPCServerError.make(e))),
                loop=self.loop)

    async def handle_request(self, source, msgID, method, *args, **kw):
        """
        Handle a request from source and send the response.

        Coroutine handlers are awaited; plain callables run in
        :attr:`executor`. A ``timeout`` keyword is dropped if the handler does
        not accept it. Errors are returned to the client as exceptions.
        """
        try:
            func = self.getRequestHandler(method, source=source, msgID=msgID)
            if 'timeout' in kw and not acceptArg(func, 'timeout'):
                del kw['timeout']
            if inspect.iscoroutinefunction(func):
                result = await func(*args, **kw)
            else:
                result = await self.loop.run_in_executor(
                    self.executor, functools.partial(func, *args, **kw))
            # A handler may hand back an awaitable instead of a value.
            if isinstance(result, Awaitable):
                result = await result
        except QuLabRPCError as e:
            result = e
        except Exception as e:  # pylint: disable=broad-except
            result = QuLabRPCServerError.make(e)
        try:
            msg = pack(result)
        except Exception as e:  # pylint: disable=broad-except
            # The result cannot be serialised; report that rather than
            # leaving the client to wait for a timeout.
            msg = pack(QuLabRPCServerError.make(e))
        await self.response(source, msgID, msg)
class ZMQServer(RPCServerMixin):
    """RPC server exposing the attributes of a module over a ZeroMQ ROUTER socket."""

    def __init__(self, loop=None):
        self.zmq_main_task = None
        self.zmq_ctx = None
        self.zmq_socket = None
        self._port = 0
        self._loop = loop or asyncio.get_event_loop()
        self._module = None
        self._executor = None

    @property
    def executor(self):
        """Executor for synchronous handlers (None: the loop's default)."""
        return self._executor

    def set_module(self, mod):
        """Set the object whose (dotted) attributes can be called remotely."""
        self._module = mod

    async def sendto(self, data, address):
        """Send ``data`` to the peer whose ROUTER identity is ``address``."""
        # Await the send so errors surface here instead of being lost in an
        # unobserved future.
        await self.zmq_socket.send_multipart([address, data])

    def getRequestHandler(self, methodNane, **kw):
        """
        Resolve a dotted ``methodNane`` to an attribute of the served module.

        The name arrives from the network, so dunder components (such as
        ``__builtins__`` or ``__class__``) are refused to keep callers from
        walking into interpreter internals.
        """
        path = methodNane.split('.')
        for name in path:
            if name.startswith('__') and name.endswith('__'):
                raise AttributeError(
                    f'access to {name!r} is not allowed over RPC')
        ret = self._module
        for name in path:
            ret = getattr(ret, name)
        return ret

    @property
    def loop(self):
        """Event loop used by this server."""
        return self._loop

    @property
    def port(self):
        """TCP port bound by :meth:`run` (0 until the socket is bound)."""
        return self._port

    def set_socket(self, sock):
        """Set the socket used by :meth:`sendto`."""
        self.zmq_socket = sock

    def start(self):
        """Start serving: schedule :meth:`run` on the event loop."""
        super().start()
        self.zmq_ctx = zmq.asyncio.Context.instance()
        self.zmq_main_task = asyncio.ensure_future(self.run(), loop=self.loop)

    def stop(self):
        """Shut down the executor (if any) and cancel the serving task."""
        if self._executor is not None:
            self._executor.shutdown()
            # A shut-down executor cannot accept work; fall back to default.
            self._executor = None
        if self.zmq_main_task is not None and not self.zmq_main_task.done():
            self.zmq_main_task.cancel()
        super().stop()

    async def run(self):
        """Bind a ROUTER socket to a random TCP port and dispatch messages."""
        with self.zmq_ctx.socket(zmq.ROUTER, io_loop=self._loop) as sock:
            sock.setsockopt(zmq.LINGER, 0)
            self._port = sock.bind_to_random_port('tcp://*')
            self.set_socket(sock)
            while True:
                frames = await sock.recv_multipart()
                # Expect [identity, payload]; skip anything else instead of
                # letting one bad message kill the server loop.
                if len(frames) != 2:
                    log.warning('dropping message with %d frames', len(frames))
                    continue
                addr, data = frames
                log.debug('received data from %r', addr.hex())
                self.handle(addr, data)
class ZMQRPCCallable:
    """
    Proxy for a (possibly dotted) remote method name.

    Calling it forwards to ``owner.performMethod(methodNane, args, kw)``;
    attribute access extends the dotted name, e.g. ``proxy.a.b``.
    """

    def __init__(self, methodNane, owner):
        self.methodNane = methodNane
        self.owner = owner

    def __call__(self, *args, **kw):
        return self.owner.performMethod(self.methodNane, args, kw)

    def __getattr__(self, name):
        # Never fabricate dunder attributes: protocol probes such as
        # ``__deepcopy__`` or ``__setstate__`` (copy/pickle) would otherwise
        # become remote calls, or recurse while the instance is still empty.
        if name.startswith('__') and name.endswith('__'):
            raise AttributeError(name)
        return ZMQRPCCallable(f"{self.methodNane}.{name}", self.owner)
class ZMQClient(RPCClientMixin):
    """
    RPC client over a ZeroMQ DEALER socket connected to ``addr``.

    Unknown attributes become remote-method proxies, so ``client.foo(1)``
    returns a future for the remote call ``foo(1)``.
    """

    def __init__(self, addr, timeout=10, loop=None):
        self._loop = loop or asyncio.get_event_loop()
        self.set_timeout(timeout)
        self.addr = addr
        self._ctx = zmq.asyncio.Context()
        self.zmq_socket = self._ctx.socket(zmq.DEALER, io_loop=self._loop)
        self.zmq_socket.setsockopt(zmq.LINGER, 0)
        self.zmq_socket.connect(self.addr)
        self.zmq_main_task = asyncio.ensure_future(self.run(), loop=self.loop)

    def __del__(self):
        # Read the instance dict directly: __getattr__ would turn attributes
        # missing after a failed __init__ into RPC proxies.
        state = vars(self)
        task = state.get('zmq_main_task')
        if task is not None:
            try:
                task.cancel()
            except RuntimeError:
                # The event loop may already be closed (e.g. at exit).
                pass
        sock = state.get('zmq_socket')
        if sock is not None:
            sock.close()
        if '_loop' in state:
            try:
                self.close()
            except RuntimeError:
                pass
        ctx = state.get('_ctx')
        if ctx is not None:
            # The context is private to this client; release it too. The
            # socket was closed above with LINGER 0, so this does not block.
            ctx.term()

    @property
    def loop(self):
        """Event loop used by this client."""
        return self._loop

    async def ping(self, timeout=1):
        """Ping the server; True if it answers within ``timeout`` seconds."""
        return await super().ping(self.addr, timeout=timeout)

    async def sendto(self, data, addr):
        """Send ``data`` to the server (``addr`` is ignored: one peer only)."""
        await self.zmq_socket.send_multipart([data])

    async def run(self):
        """Receive frames from the server forever and dispatch them."""
        while True:
            frames = await self.zmq_socket.recv_multipart()
            # Expect exactly one frame; skip anything else rather than
            # letting one bad message kill the receive loop.
            if len(frames) != 1:
                log.warning('dropping message with %d frames', len(frames))
                continue
            self.handle(self.addr, frames[0])

    def performMethod(self, methodNane, args, kw):
        """Call ``methodNane`` remotely; returns a future for the result."""
        return self.remoteCall(self.addr, methodNane, args, kw)

    def __getattr__(self, name):
        # Dunder lookups (copy/pickle/protocol probes) must not turn into
        # remote calls.
        if name.startswith('__') and name.endswith('__'):
            raise AttributeError(name)
        return ZMQRPCCallable(name, self)