From dd8e9fb5e3d3ebbfb658b73ef1d5dc4f0ee6c5fc Mon Sep 17 00:00:00 2001
From: Kirill Bespalov <kbespalov@mirantis.com>
Date: Fri, 30 Sep 2016 18:19:58 +0300
Subject: [PATCH] [simulator] Automatic stopping of rpc-servers

This patch provide the sync flag:

  simulator.py rpc-client --sync call
  simulator.py rpc-client --sync fanout

The --sync values means next:
- call:   to send sync msg via rpc.call
- fanout: to broadcast sync msg via rpc.fanout to all servers on topic

When clients has sent all messages, the rpc-server will be stopped
automatically and dump statistics to a file (if --json-file is used).

This is much usefull than rough killing the process of the server in
benchmark frameworks.

Change-Id: I06fd8dbbcdc8b2b9f13029029f730b417ff128ce
---
 tools/simulator.py | 138 ++++++++++++++++++++++++++++++++++++++-------
 1 file changed, 119 insertions(+), 19 deletions(-)

diff --git a/tools/simulator.py b/tools/simulator.py
index c4c2308f4..f63028922 100755
--- a/tools/simulator.py
+++ b/tools/simulator.py
@@ -24,6 +24,7 @@ import os
 import random
 import signal
 import six
+import socket
 import string
 import sys
 import threading
@@ -39,6 +40,7 @@ from oslo_utils import timeutils
 LOG = logging.getLogger()
 RANDOM_GENERATOR = None
 CURRENT_PID = None
+CURRENT_HOST = None
 CLIENTS = []
 MESSAGES = []
 IS_RUNNING = True
@@ -320,6 +322,37 @@ class RpcEndpoint(object):
         return reply
 
 
+class ServerControlEndpoint(object):
+    def __init__(self, controlled_server):
+        self.connected_clients = set()
+        self.controlled_server = controlled_server
+
+    def sync_start(self, ctx, message):
+        """Handle start reports from clients"""
+
+        client_id = message['id']
+        LOG.info('The client %s started to send messages' % client_id)
+        self.connected_clients.add(client_id)
+
+    def sync_done(self, ctx, message):
+        """Handle done reports from clients"""
+
+        client_id = message['id']
+        LOG.info('The client %s finished msg sending.' % client_id)
+
+        if client_id in self.connected_clients:
+            self.connected_clients.remove(client_id)
+
+        if not self.connected_clients:
+            LOG.info(
+                'The clients sent all messages. Shutting down the server..')
+            threading.Timer(1, self._stop_server_with_delay).start()
+
+    def _stop_server_with_delay(self):
+        self.controlled_server.stop()
+        self.controlled_server.wait()
+
+
 class Client(object):
     def __init__(self, client_id, client, method, has_result,
                  wait_after_msg):
@@ -341,6 +374,12 @@ class Client(object):
             self.round_trip_messages = MessageStatsCollector(
                 'round-trip-%s' % client_id)
 
+    def host_based_id(self):
+        _id = "%(client_id)s %(salt)s@%(hostname)s"
+        return _id % {'hostname': CURRENT_HOST,
+                      'salt': hex(id(self))[2:],
+                      'client_id': self.client_id}
+
     def send_msg(self):
         msg = make_message(self.seq, MESSAGES[self.position], time.time())
         self.sent_messages.push(msg)
@@ -366,12 +405,55 @@ class Client(object):
 
 class RPCClient(Client):
     def __init__(self, client_id, transport, target, timeout, is_cast,
-                 wait_after_msg):
-        client = rpc.RPCClient(transport, target).prepare(timeout=timeout)
+                 wait_after_msg, sync_mode=False):
+
+        client = rpc.RPCClient(transport, target)
         method = _rpc_cast if is_cast else _rpc_call
 
-        super(RPCClient, self).__init__(client_id, client, method,
+        super(RPCClient, self).__init__(client_id,
+                                        client.prepare(timeout=timeout),
+                                        method,
                                         not is_cast, wait_after_msg)
+        self.sync_mode = sync_mode
+        self.is_sync = False
+
+        # prepare the sync client
+        if sync_mode:
+            if sync_mode == 'call':
+                self.sync_client = self.client
+            else:
+                self.sync_client = client.prepare(fanout=True, timeout=timeout)
+
+    def send_msg(self):
+        if self.sync_mode and not self.is_sync:
+            self.is_sync = self.sync_start()
+        super(RPCClient, self).send_msg()
+
+    def sync_start(self):
+        try:
+            msg = {'id': self.host_based_id()}
+            method = _rpc_call if self.sync_mode == 'call' else _rpc_cast
+            method(self.sync_client, msg, 'sync_start')
+        except Exception:
+            LOG.error('The client: %s failed to sync with %s.' %
+                      (self.client_id, self.client.target))
+            return False
+        LOG.info('The client: %s successfully sync with  %s' % (
+            self.client_id, self.client.target))
+        return True
+
+    def sync_done(self):
+        try:
+            msg = {'id': self.host_based_id()}
+            method = _rpc_call if self.sync_mode == 'call' else _rpc_cast
+            method(self.sync_client, msg, 'sync_done')
+        except Exception:
+            LOG.error('The client: %s failed finish the sync with %s.'
+                      % (self.client_id, self.client.target))
+            return False
+        LOG.info('The client: %s successfully finished sync with %s'
+                 % (self.client_id, self.client.target))
+        return True
 
 
 class NotifyClient(Client):
@@ -432,9 +514,13 @@ def run_server(server, duration=None):
 
 
 def rpc_server(transport, target, wait_before_answer, executor, duration):
+
     endpoints = [RpcEndpoint(wait_before_answer)]
-    server = rpc.get_rpc_server(transport, target, endpoints,
-                                executor=executor)
+    server = rpc.get_rpc_server(transport, target, endpoints, executor)
+
+    # make the rpc server controllable by rpc clients
+    endpoints.append(ServerControlEndpoint(server))
+
     LOG.debug("starting RPC server for target %s", target)
 
     run_server(server, duration=duration)
@@ -444,15 +530,18 @@ def rpc_server(transport, target, wait_before_answer, executor, duration):
 
 @wrap_sigexit
 def spawn_rpc_clients(threads, transport, targets, wait_after_msg, timeout,
-                      is_cast, messages_count, duration):
+                      is_cast, messages_count, duration, sync_mode):
     p = eventlet.GreenPool(size=threads)
     targets = itertools.cycle(targets)
+
     for i in six.moves.range(threads):
         target = next(targets)
         LOG.debug("starting RPC client for target %s", target)
         client_builder = functools.partial(RPCClient, i, transport, target,
-                                           timeout, is_cast, wait_after_msg)
-        p.spawn_n(send_messages, i, client_builder, messages_count, duration)
+                                           timeout, is_cast, wait_after_msg,
+                                           sync_mode)
+        p.spawn_n(send_messages, i, client_builder,
+                  messages_count, duration)
     p.waitall()
 
 
@@ -493,12 +582,17 @@ def send_messages(client_id, client_builder, messages_count, duration):
                 break
         LOG.debug("Client %d has sent %d messages", client_id, messages_count)
 
-    time.sleep(1)  # wait for replies to be collected
+    # wait for replies to be collected
+    time.sleep(1)
+
+    # send stop request to the rpc server
+    if isinstance(client, RPCClient) and client.is_sync:
+        client.sync_done()
 
 
-def _rpc_call(client, msg):
+def _rpc_call(client, msg, remote_method='info'):
     try:
-        res = client.call({}, 'info', message=msg)
+        res = client.call({}, remote_method, message=msg)
     except Exception as e:
         LOG.exception('Error %s on CALL for message %s', str(e), msg)
         raise
@@ -507,9 +601,9 @@ def _rpc_call(client, msg):
         return res
 
 
-def _rpc_cast(client, msg):
+def _rpc_cast(client, msg, remote_method='info'):
     try:
-        client.cast({}, 'info', message=msg)
+        client.cast({}, remote_method, message=msg)
     except Exception as e:
         LOG.exception('Error %s on CAST for message %s', str(e), msg)
         raise
@@ -663,6 +757,9 @@ def main():
     client.add_argument('--is-fanout', dest='is_fanout', action='store_true',
                         help='fanout=True for CAST messages')
 
+    client.add_argument('--sync', dest='sync', choices=('call', 'fanout'),
+                        help="stop server when all msg was sent by clients")
+
     args = parser.parse_args()
 
     _setup_logging(is_debug=args.debug)
@@ -717,14 +814,16 @@ def main():
         show_client_stats(CLIENTS, args.json_filename)
 
     elif args.mode == 'rpc-client':
-        targets = [target.partition('.')[::2] for target in args.targets]
-        targets = [messaging.Target(
-            topic=topic, server=server_name, fanout=args.is_fanout) for
-            topic, server_name in targets]
+
+        targets = []
+        for target in args.targets:
+            tp, srv = target.partition('.')[::2]
+            t = messaging.Target(topic=tp, server=srv, fanout=args.is_fanout)
+            targets.append(t)
+
         spawn_rpc_clients(args.threads, TRANSPORT, targets,
                           args.wait_after_msg, args.timeout, args.is_cast,
-                          args.messages, args.duration)
-
+                          args.messages, args.duration, args.sync)
         show_client_stats(CLIENTS, args.json_filename, not args.is_cast)
 
         if args.exit_wait:
@@ -735,4 +834,5 @@ def main():
 if __name__ == '__main__':
     RANDOM_GENERATOR = init_random_generator()
     CURRENT_PID = os.getpid()
+    CURRENT_HOST = socket.gethostname()
     main()