aboutsummaryrefslogtreecommitdiff
path: root/plip/basic/parallel.py
blob: cd0b93b799395353fc16d7b52986bb416e380a6b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import itertools
import multiprocessing
from builtins import zip
from functools import partial

from numpy import asarray


class SubProcessError(Exception):
    """Exception that carries an exit code alongside the wrapped error.

    The first argument ``e`` (an error object or message) is handed to the
    ``Exception`` base class; ``exitcode`` is stored on the instance and
    defaults to 1.
    """

    def __init__(self, e, exitcode=1):
        # Base-class initialisation and the attribute assignment are
        # independent of each other.
        super(SubProcessError, self).__init__(e)
        self.exitcode = exitcode


def universal_worker(input_pair):
    """Invoke a packed call description and return its result.

    ``input_pair`` is a triplet ``(function, argument, kwargs)``: the
    function is called with the single positional argument and the
    dictionary expanded as keyword arguments.
    """
    target, positional, keywords = input_pair
    outcome = target(positional, **keywords)
    return outcome


def pool_args(function, sequence, kwargs):
    """Pair every element of ``sequence`` with ``function`` and ``kwargs``.

    Returns the result of ``zip`` over the repeated function, the sequence
    and the repeated keyword dictionary, i.e. one 3-tuple
    ``(function, element, kwargs)`` per element of ``sequence``.
    """
    # itertools.repeat is endless, so zip stops when ``sequence`` runs out.
    repeated_function = itertools.repeat(function)
    repeated_kwargs = itertools.repeat(kwargs)
    return zip(repeated_function, sequence, repeated_kwargs)


def parallel_fn(f):
    """Return a parallel version of the given function ``f``.

    ``f`` must take one positional argument and may take an arbitrary number
    of keyword arguments. The returned callable has the signature
    ``(sequence, **kwargs)``: it applies ``f`` to every element of
    ``sequence`` in a ``multiprocessing.Pool`` and returns the results that
    are not ``None`` as a numpy array.

    The optional keyword ``processes`` sets the pool size and is not
    forwarded to ``f``; when omitted, ``multiprocessing.cpu_count()`` is used.
    """

    def simple_parallel(func, sequence, **args):
        """Apply ``func`` to each element of ``sequence`` with keyword args ``args``."""
        # "processes" configures the pool and must not reach ``func``.
        if "processes" in args:
            processes = args.pop("processes")
        else:
            processes = multiprocessing.cpu_count()

        pool = multiprocessing.Pool(processes)  # depends on available cores
        try:
            result = pool.map_async(universal_worker, pool_args(func, sequence, args))
        finally:
            # Always shut the pool down, even when building or submitting the
            # tasks fails, so that worker processes are not leaked.
            pool.close()
            pool.join()

        # Fetching the results re-raises any exception raised in a worker.
        cleaned = [x for x in result.get() if x is not None]
        return asarray(cleaned)

    return partial(simple_parallel, f)