Moshe Vayner
Moshe Vayner

Reputation: 798

Multiple SSH Connections in a Python 2.7 script- Multiprocessing Vs Threading

I have a script that gets a list of nodes as an argument (could be 10 or even 50), and connects to each by SSH to run a service restart command. At the moment, I'm using multiprocessing in order to parallelize the script (getting the batch size as an argument as well), however I've heard that the threading module could help me perform my tasks in a quicker and easier-to-manage way (I'm using try..except KeyboardInterrupt with sys.exit() and pool.terminate(), but it won't stop the entire script because it's a different process). Since I understand that multithreading is more lightweight and easier to manage for my case, I am trying to convert my script to use threading instead of multiprocessing, but it doesn't work properly.

The current code in multiprocessing (works):

def restart_service(node, initd_tup):
    """
    Connect to *node* via SSH and restart its linux service.

    node -- hostname to connect to (also the key into initd_tup).
    initd_tup -- dict mapping node name -> linux service (initd) name.

    Relies on module-level globals defined elsewhere in the file:
    ``ssh`` (paramiko client), ``logger``, ``pool`` (multiprocessing pool),
    ``Color`` and ``general_utils``.
    NOTE(review): a single shared ``ssh`` client is reused by every pool
    worker; a per-call client may be safer — confirm how ``ssh`` is set up.
    """
    command = 'service {0} restart'.format(initd_tup[node])
    logger.info('[{0}] Connecting to {0} in order to restart {1} service...'.format(node, initd_tup[node]))
    try:
        # Shared module-global paramiko client; keys/auth configured elsewhere.
        ssh.connect(node)
        stdin, stdout, stderr = ssh.exec_command(command)
        result = stdout.read()
        # Empty stdout is treated as failure: report stderr instead.
        if not result:
            result_err = stderr.read()
            print '{0}{1}[{2}] ERROR: {3}{4}'.format(Color.BOLD, Color.RED, node, result_err, Color.END)
            logger.error('[{0}]  Result of command {1} output: {2}'.format(node, command, result_err))
        else:
            print '{0}{1}{2}[{3}]{4}\n{5}'.format(Color.BOLD, Color.UNDERLINE, Color.GREEN, node, Color.END, result)
            logger.info('[{0}]  Result of command {1} output: {2}'.format(node, command, result.replace("\n", "... ")))
        # NOTE(review): skipped when an exception fires above — the
        # connection is then never closed explicitly.
        ssh.close()
    except paramiko.AuthenticationException:
        print "{0}{1}ERROR! SSH failed with Authentication Error. Make sure you run the script as root and try again..{2}".format(Color.BOLD, Color.RED, Color.END)
        logger.error('SSH Authentication failed, thrown error message to the user to make sure script is run with root permissions')
        # Abort the whole batch: auth will fail on every node anyway.
        pool.terminate()
    except socket.error as error:
        # Per-node connectivity failure: log and keep the other workers going.
        print("[{0}]{1}{2} ERROR! SSH failed with error: {3}{4}\n".format(node, Color.RED, Color.BOLD, error, Color.END))
        logger.error("[{0}] SSH failed with error: {1}".format(node, error))
    except KeyboardInterrupt:
        # Ctrl-C inside a worker: stop the pool and run the shared shutdown path.
        pool.terminate()
        general_utils.terminate(logger)


def convert_to_tuple(a_b):
    """Adapter for pool.map: unpack a (node, initd_dict) pair and forward it to restart_service."""
    node, initd_tup = a_b
    return restart_service(node, initd_tup)


def iterate_nodes_and_call_exec_func(nodes_list):
    """
    Iterate over the list of nodes to process,
    accumulate nodes up to the batch size provided (or 1 if not provided),
    then use the multiprocessing pool to call restart_service on each batch
    in parallel. If batch_sleep was provided, sleep that long between batches.

    Fixes vs the original:
    * the whole loop is wrapped in the KeyboardInterrupt handler — before,
      only the final leftover pool.map was protected, so Ctrl-C during an
      in-loop batch left the pool running;
    * the managed dict is cleared and reused instead of being replaced with
      a plain ``{}``, which silently changed its type after the first batch.
    """
    global pool
    general_utils.banner('Initiating service restart')
    pool = multiprocessing.Pool(10)
    manager = multiprocessing.Manager()
    work = manager.dict()
    try:
        for line in nodes_list:
            work[line] = general_utils.get_initd(logger, args, line)
            if len(work) >= int(args.batch):
                pool.map(convert_to_tuple, itertools.izip(work.keys(), itertools.repeat(work)))
                work.clear()  # reuse the managed dict rather than rebinding to a plain dict
                if int(args.batch_sleep) > 0:
                    logger.info('*** Sleeping for %d seconds before moving on to next batch ***', int(args.batch_sleep))
                    general_utils.sleep_func(int(args.batch_sleep))
        # Flush the final, possibly partial, batch.
        if len(work) > 0:
            pool.map(convert_to_tuple, itertools.izip(work.keys(), itertools.repeat(work)))
    except KeyboardInterrupt:
        pool.terminate()
        general_utils.terminate(logger)

And here's what I've tried to do with Threading, which doesn't work (when I assign a batch_size larger than 1, the script simply gets stuck and I have to kill it forcefully).

def parse_args():
    """Define the argument parser, and the arguments to accept.."""
    global args, parser
    parser = MyParser(description=__doc__)
    parser.add_argument('-H', '--host', help='List of hosts to process, separated by "," and NO SPACES!')
    parser.add_argument('--batch', help='Do requests in batches', default=1)
    args = parser.parse_args()

    # If no arguments were passed, print the help file and exit with ERROR..
    if len(sys.argv) == 1:
        parser.print_help()
        print '\n\nERROR: No arguments passed!\n'
        sys.exit(3)


def do_work(node):
    logger.info('[{0}]'.format(node))
    try:
        ssh.connect(node)
        stdin, stdout, stderr = ssh.exec_command('hostname ; date')
        print stdout.read()
        ssh.close()
    except:
        print 'ERROR!'
        sys.exit(2)


def worker():
    """
    Daemon-thread loop: pull hosts off the shared queue ``q`` and process them.

    Fix: call ``q.task_done()`` in a ``finally`` block. In the original, if
    do_work() raised (e.g. the SystemExit raised by its ``sys.exit(2)``
    escapes do_work's handler), task_done() was never called, the queue's
    unfinished-task count never reached zero, and ``q.join()`` in iterate()
    blocked forever — the "script gets stuck" symptom.
    """
    while True:
        item = q.get()
        try:
            do_work(item)
        finally:
            q.task_done()


def iterate():
    """Queue every host from --host, start --batch daemon worker threads, and wait for the queue to drain."""
    for host in args.host.split(","):
        q.put(host)

    # One daemon thread per batch slot; daemons die with the main thread.
    for _ in range(int(args.batch)):
        thread = Thread(target=worker)
        thread.daemon = True
        thread.start()

    # Block until every queued host has been marked done by a worker.
    q.join()


def main():
    """Entry point: parse the command line, then run the threaded flow; exit 1 on Ctrl-C."""
    parse_args()
    try:
        iterate()
    except KeyboardInterrupt:
        exit(1)

In the script log I see a WARNING generated by Paramiko as below:

2016-01-04 22:51:37,613 WARNING: Oops, unhandled type 3

I tried to Google this unhandled type 3 error, but didn't find anything related to my issue, since it's talking about 2 factor authentication or trying to connect via both password and SSH key at the same time, but I'm only loading the host keys without providing any password to the SSH Client.

I would appreciate any help on this matter..

Upvotes: 2

Views: 5141

Answers (2)

Moshe Vayner
Moshe Vayner

Reputation: 798

In addition to using the pssh module, after a more thorough troubleshooting effort, I was able to fix the original code that I posted in the question using the native Threading module, by creating a new paramiko client for every thread rather than using the same client for all threads. So basically (only updating the do_work function from the original question), here's the change:

def do_work(node):
    logger.info('[{0}]'.format(node))
    try:
        ssh = paramiko.SSHClient() 
        ssh.connect(node)
        stdin, stdout, stderr = ssh.exec_command('hostname ; date')
        print stdout.read()
        ssh.close()
    except:
        print 'ERROR!'
        sys.exit(2)

When done this way, the native Threading module works perfectly!

Upvotes: 0

Moshe Vayner
Moshe Vayner

Reputation: 798

Managed to solve my problem using parallel-ssh module.

Here's the code, fixed with my desired actions:

def iterate_nodes_and_call_exec_func(nodes):
    """
    Get a dict as an argument, containing linux services (initd) as the keys,
    and a list of nodes on which each linux service needs to be restarted.
    Walk each service's node list in batches no larger than args.batch
    (or 1 if not provided), calling restart_service() on every batch.
    If batch_sleep arg was provided, sleep that many seconds between batches.

    Fix: cast args.batch / args.batch_sleep to int once up front. argparse
    delivers CLI-supplied values as strings, so the original
    ``len(work[initd]) == args.batch`` comparison ('5' == 5) was never true
    for command-line values, and ``args.batch_sleep > 0`` compared a string
    against an int.
    """
    batch_size = int(args.batch)
    batch_sleep = int(args.batch_sleep)

    for initd in nodes.keys():
        pending = []
        count = 0
        for node in nodes[initd]:
            count += 1
            pending.append(node)
            if len(pending) == batch_size:
                restart_service(pending, initd)
                pending = []
                # Don't sleep after the very last batch of this service.
                if batch_sleep > 0 and count < len(nodes[initd]):
                    logger.info('*** Sleeping for %d seconds before moving on to next batch ***', batch_sleep)
                    general_utils.sleep_func(batch_sleep)
        # Flush the final, possibly partial, batch.
        if len(pending) > 0:
            restart_service(pending, initd)


def restart_service(nodes, initd):
    """
    Get a list of nodes and a linux service name as arguments,
    then connect via the parallel-ssh module and run the service restart
    command on all of them in parallel.

    nodes -- list of hostnames to process in this batch.
    initd -- linux service (initd) name to restart.

    Relies on module-level globals: ``args``, ``logger``, ``Color``,
    ``general_utils`` and the imported ``pssh`` module.
    """
    command = 'service {0} restart'.format(initd)
    logger.info('Connecting to {0} to restart the {1} service...'.format(nodes, initd))
    try:
        # NOTE(review): args.batch may still be a string here if it came from
        # the command line — confirm pssh accepts it as pool_size.
        client = pssh.ParallelSSHClient(nodes, pool_size=args.batch, timeout=10, num_retries=1)
        output = client.run_command(command, sudo=True)
        for node in output:
            for line in output[node]['stdout']:
                # NOTE(review): get_exit_code is re-evaluated for every output
                # line; the result could be fetched once per node instead.
                if client.get_exit_code(output[node]) == 0:
                    print '[{0}]{1}{2}  {3}{4}'.format(node, Color.BOLD, Color.GREEN, line, Color.END)
                else:
                    print '[{0}]{1}{2}  ERROR! {3}{4}'.format(node, Color.BOLD, Color.RED, line, Color.END)
                    logger.error('[{0}]  Result of command {1} output: {2}'.format(node, command, line))

    except pssh.AuthenticationException:
        # Auth failure affects every node alike: report and abort the script.
        print "{0}{1}ERROR! SSH failed with Authentication Error. Make sure you run the script as root and try again..{2}".format(Color.BOLD, Color.RED, Color.END)
        logger.error('SSH Authentication failed, thrown error message to the user to make sure script is run with root permissions')
        sys.exit(2)

    except pssh.ConnectionErrorException as error:
        # Retry the rest of the batch, skipping past the node that failed.
        # NOTE(review): assumes error[1] is the failing hostname and that it
        # appears exactly once in ``nodes`` — confirm against pssh docs.
        print("[{0}]{1}{2} ERROR! SSH failed with error: {3}{4}\n".format(error[1], Color.RED, Color.BOLD, error[3], Color.END))
        logger.error("[{0}] SSH Failed with error: {1}".format(error[1], error[3]))
        restart_service(nodes[nodes.index(error[1])+1:], initd)

    except KeyboardInterrupt:
        # Ctrl-C: run the shared shutdown path.
        general_utils.terminate(logger)


def generate_nodes_by_initd_dict(nodes_list):
    """
    Get a list of nodes as an argument and group them by the linux service
    (initd) that has to be processed on each, as reported by
    general_utils.get_initd. The resulting {initd: [nodes]} dict is handed
    straight to iterate_nodes_and_call_exec_func.

    nodes_list -- iterable of node hostnames.
    Returns whatever iterate_nodes_and_call_exec_func returns.
    """
    nodes = {}
    for node in nodes_list:
        initd = general_utils.get_initd(logger, args, node)
        # setdefault replaces the original "if initd in nodes.keys()" test,
        # which built and scanned a fresh key list on every iteration (Py2).
        nodes.setdefault(initd, []).append(node)

    return iterate_nodes_and_call_exec_func(nodes)


def main():
    """Entry point: parse args, build the service->nodes map and process it; always run wrap_up."""
    parse_args()
    try:
        general_utils.init_script('Service Restart', logger, log)
        log_args(logger, args)
        nodes_list = general_utils.generate_nodes_list(args, logger, ['service', 'datacenter', 'lob'])
        generate_nodes_by_initd_dict(nodes_list)
    except KeyboardInterrupt:
        general_utils.terminate(logger)
    finally:
        # Runs on success, Ctrl-C and any uncaught failure alike.
        general_utils.wrap_up(logger)


if __name__ == '__main__':
    main()

Upvotes: 3

Related Questions