# Pylm, a framework to build components for high performance distributed
# applications. Copyright (C) 2016 NFQ Solutions
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from pylm.parts.core import zmq_context
from pylm.parts.messages_pb2 import PalmMessage
from threading import Thread
from uuid import uuid4
import logging
import time
import zmq
import sys
class Client(object):
    """
    Client to connect to parallel servers.

    Every instance gets its own random UUID, which is used as the identity
    of the REQ socket that talks to the cache service, as the ``client``
    field of every outgoing message and as the subscription topic for the
    messages that come back.

    :param server_name: Server you are connecting to
    :param db_address: Address for the cache service, for first connection
        or configuration.
    :param push_address: Address of the push service of the server to pull
        from. If empty, it is fetched from the server unless ``this_config``
        is set.
    :param sub_address: Address of the pub service of the server to
        subscribe to. If empty, it is fetched from the server unless
        ``this_config`` is set.
    :param session: Name of the pipeline if the session has to be reused.
        If empty, a random pipeline name is generated.
    :param logging_level: Specify the logging level.
    :param this_config: Do not fetch configuration from the server
    """

    def __init__(self, server_name: str,
                 db_address: str,
                 push_address: str = None,
                 sub_address: str = None,
                 session: str = None,
                 logging_level: int = logging.INFO,
                 this_config=False):
        self.server_name = server_name
        self.db_address = db_address

        # session_set records whether the caller chose the pipeline name;
        # set() only prefixes keys with it in that case.
        if session:
            self.pipeline = session
            self.session_set = True
        else:
            self.pipeline = str(uuid4())
            self.session_set = False

        self.uuid = str(uuid4())

        # REQ socket to the cache service; also used to fetch configuration.
        self.db = zmq_context.socket(zmq.REQ)
        self.db.identity = self.uuid.encode('utf-8')
        self.db.connect(db_address)

        self.sub_address = sub_address
        self.push_address = push_address

        # Basic console logging. The logger is named after the client uuid,
        # so the handler is never attached twice to the same logger.
        self.logger = logging.getLogger(name=self.uuid)
        handler = logging.StreamHandler(sys.stdout)
        handler.setFormatter(
            logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
        self.logger.addHandler(handler)
        self.logger.setLevel(logging_level)

        if this_config:
            self.logger.warning('Not fetching config from the server')
        else:
            self.logger.info('Fetching configuration from the server')
            self._get_config_from_master()

        # PUB-SUB takes a while
        time.sleep(0.5)

    def _get_config_from_master(self):
        """
        Check the server name and fetch the addresses that were not given.

        :return: Dictionary with the keys ``sub_address`` and
            ``push_address``
        :raises ValueError: If the name stored in the server differs from
            ``server_name``
        """
        name = self.get('name').decode('utf-8')
        if name != self.server_name:
            raise ValueError('You are connecting to the wrong server')

        if not self.sub_address:
            self.sub_address = self.get('pub_address').decode('utf-8')
            self.logger.info(
                'CLIENT %s: Got subscription address: %s',
                self.uuid, self.sub_address)

        if not self.push_address:
            self.push_address = self.get('pull_address').decode('utf-8')
            self.logger.info(
                'CLIENT %s: Got push address: %s',
                self.uuid, self.push_address)

        return {'sub_address': self.sub_address,
                'push_address': self.push_address}

    def clean(self):
        """
        Close the socket connected to the cache service.
        """
        self.db.close()

    @staticmethod
    def _join_stages(function):
        """
        Turn a pipelined job (list or tuple of ``server.function`` strings)
        into the space-separated string that goes in the message. Any other
        value (a single-stage job) is returned untouched.
        """
        if isinstance(function, (list, tuple)):
            return ' '.join(function)
        return function

    def _new_message(self, function, payload, cache=''):
        """
        Build the first-stage message of a job for this client and pipeline.
        """
        message = PalmMessage()
        message.function = function
        message.stage = 0
        message.pipeline = self.pipeline
        message.client = self.uuid
        message.payload = payload
        if cache:
            message.cache = cache
        return message

    def _recv_payload(self, sub_socket):
        """
        Receive one two-frame message and return its payload.

        :raises ValueError: If the topic frame is not this client's uuid
        """
        client, message_data = sub_socket.recv_multipart()
        if client.decode('utf-8') != self.uuid:
            raise ValueError('The client got a message that does not belong')

        message = PalmMessage()
        message.ParseFromString(message_data)
        return message.payload

    def _db_request(self, action, payload, cache=None):
        """
        Send a ``server_name.action`` request to the cache service and
        return the raw reply.
        """
        message = PalmMessage()
        # For a cache request the pipeline is not important
        message.pipeline = str(uuid4())
        message.client = self.uuid
        message.stage = 0
        message.function = '.'.join([self.server_name, action])
        message.payload = payload
        if cache:
            message.cache = cache

        self.db.send(message.SerializeToString())
        return self.db.recv()

    def _sender(self, socket, function, generator, cache):
        """
        Send one message per payload yielded by ``generator``.

        This runs in the background thread started by :meth:`job`. Since
        this thread is the only user of ``socket``, it also closes it once
        the generator is exhausted or fails.
        """
        try:
            for payload in generator:
                message = self._new_message(function, payload, cache)
                socket.send(message.SerializeToString())
        finally:
            socket.close()

    def job(self, function, generator, messages: int = sys.maxsize,
            cache: str = ''):
        """
        Submit a job with multiple messages to a server.

        This is a generator function: nothing is connected or sent until
        the returned iterator is consumed.

        :param function: String or list of strings following the format
            ``server.function``.
        :param generator: A generator that yields a series of binary
            messages.
        :param messages: Number of messages expected to be sent back to the
            client. Defaults to infinity (sys.maxsize)
        :param cache: Cache data included in the message
        :return: an iterator with the messages that are sent back to the
            client.
        :raises ValueError: If a received message is addressed to a
            different client
        """
        function = self._join_stages(function)

        push_socket = zmq_context.socket(zmq.PUSH)
        sub_socket = zmq_context.socket(zmq.SUB)
        sender_started = False
        try:
            push_socket.connect(self.push_address)
            sub_socket.setsockopt_string(zmq.SUBSCRIBE, self.uuid)
            sub_socket.connect(self.sub_address)

            # Remember that sockets are not thread safe: from here on the
            # push socket belongs to the sender thread only.
            sender_thread = Thread(
                target=self._sender,
                args=(push_socket, function, generator, cache))

            # Sender runs in background.
            sender_thread.start()
            sender_started = True

            for _ in range(messages):
                yield self._recv_payload(sub_socket)
        finally:
            # Also reached when the consumer stops iterating early.
            sub_socket.close()
            if not sender_started:
                # The sender thread never took ownership of the socket.
                push_socket.close()

    def eval(self, function, payload: bytes, messages: int = 1,
             cache: str = ''):
        """
        Execute single job.

        :param function: String or list of strings following the format
            ``server.function``.
        :param payload: Binary message to be sent
        :param messages: Number of messages expected to be sent back to the
            client
        :param cache: Cache data included in the message
        :return: If messages=1, the result data. Otherwise, a list with the
            results
        :raises ValueError: If a received message is addressed to a
            different client
        """
        function = self._join_stages(function)

        push_socket = zmq_context.socket(zmq.PUSH)
        sub_socket = zmq_context.socket(zmq.SUB)
        try:
            push_socket.connect(self.push_address)
            sub_socket.setsockopt_string(zmq.SUBSCRIBE, self.uuid)
            sub_socket.connect(self.sub_address)

            message = self._new_message(function, payload, cache)
            push_socket.send(message.SerializeToString())

            result = [self._recv_payload(sub_socket)
                      for _ in range(messages)]
        finally:
            push_socket.close()
            sub_socket.close()

        if messages == 1:
            return result[0]
        return result

    def set(self, value: bytes, key=None):
        """
        Sets a key value pair in the remote database. If the key is not set,
        the function returns a new key. Note that the order of the arguments
        is reversed from the usual.

        .. warning::

            If the session attribute is specified, all the keys will be
            prepended with the session id.

        :param value: Value to be stored
        :param key: Key for the k-v storage
        :return: New key or the same key
        :raises TypeError: If ``value`` is not of type bytes
        """
        if not isinstance(value, bytes):
            raise TypeError(
                'First argument {} must be of type <bytes>'.format(value))

        cache = None
        if key and self.session_set:
            cache = ''.join([self.pipeline, key])
        elif key:
            cache = key

        return self._db_request('set', value, cache).decode('utf-8')

    def get(self, key):
        """
        Gets a value from server's internal cache

        :param key: Key for the data to be selected.
        :return: Value, as the bytes replied by the server
        """
        return self._db_request('get', key.encode('utf-8'))

    def delete(self, key):
        """
        Deletes data in the server's internal cache.

        :param key: Key of the data to be deleted
        :return: Reply of the server decoded as a string
        """
        return self._db_request('delete', key.encode('utf-8')).decode('utf-8')