Project

General

Profile

« Previous | Next » 

Revision 1885

dicts.py: Turned the id_dict() factory function into the IdDict class. parallel.py: MultiProducerPool: Added share_vars(). main_loop(): Only consider the program done when the queue is empty and no tasks are still running.

View differences:

lib/parallel.py
18 18
        exc.add_msg(e, 'Tried to pickle: '+repr(value))
19 19
        raise
20 20

  
21
def vars_id_dict(locals_, globals_, *misc):
22
    '''Usage: vars_id_dict(locals(), globals(), misc...)'''
23
    vars_ = map(lambda v: v.values(), [locals_, globals_]) + list(misc)
24
    return dicts.id_dict(vars_)
25

  
26 21
def prepickle(value, vars_id_dict_):
27
    def filter_(value):
22
    def filter_(value, is_leaf):
28 23
        id_ = id(value)
29 24
        if id_ in vars_id_dict_: value = id_
30 25
        # Try pickling the value. If it fails, we'll get a full traceback here,
31 26
        # which is not provided with pickling errors in multiprocessing's Pool.
32
        else: try_pickle(value)
27
        elif is_leaf: try_pickle(value)
33 28
        return value
34 29
    return collection.rmap(filter_, value)
35 30

  
36 31
def post_unpickle(value, vars_id_dict_):
37
    def filter_(value):
38
        try: return vars_id_dict_[value] # value is an id()
39
        except KeyError: return value
32
    def filter_(value, is_leaf):
33
        if type(value) == int: value = vars_id_dict_.get(value, value)
34
            # get() returns the value itself if it isn't a known id()
35
        return value
40 36
    return collection.rmap(filter_, value)
41 37

  
42 38
class SyncPool:
......
65 61
    '''A multi-producer pool. You must call pool.main_loop() in the thread that
66 62
    created this to process new tasks.'''
67 63
    
68
    def __init__(self, processes=None, locals_=None, globals_=None, *shared):
64
    def __init__(self, processes=None, locals_={}, *shared):
69 65
        '''
70 66
        @param processes If 0, uses SyncPool
71 67
        @post The # processes actually used is made available in self.process_ct
72 68
        '''
73
        if locals_ == None: locals_ = locals()
74
        if globals_ == None: globals_ = globals()
75
        
76 69
        try:
77 70
            if processes == 0: raise ImportError('turned off')
78 71
            import multiprocessing
......
90 83
        self.process_ct = processes
91 84
        self.pool = Pool_(processes)
92 85
        self.queue = Queue_()
86
        self.active_tasks = 0
87
        
93 88
        # Values that may be pickled by id()
94
        self.vars_id_dict = vars_id_dict(locals_, globals_, *shared)
89
        self.vars_id_dict = dicts.IdDict()
90
        self.share(self, *shared).share_vars(locals_).share_vars(globals())
95 91
    
96
    def share(self, value):
97
        '''Call this on every value that may be pickled by id()'''
98
        self.vars_id_dict[id(value)] = value
92
    def share(self, *values):
93
        '''Call this on all values that should be pickled by id()'''
94
        self.vars_id_dict.add(*values)
95
        return self
99 96
    
97
    def share_vars(self, vars_):
98
        '''Call this on all vars that should be pickled by id().
99
        Usage: self.share_vars(locals())
100
        '''
101
        self.vars_id_dict.add_vars(vars_)
102
        return self
103
    
100 104
    def main_loop(self):
101
        try:
102
            while True:
103
                # block=False raises Empty immediately if the queue is empty,
104
                # which indicates that the program is done
105
                call = self.queue.get(block=False)
106
                self.pool.apply_async(call.func, self.post_unpickle(call.args),
107
                    self.post_unpickle(call.kw_args), call.callback)
108
        except Queue.Empty: pass
105
        '''Prime the task queue with at least one task before calling this''' 
106
        while True:
107
            try: call = self.queue.get(timeout=0.1) # sec
108
            except Queue.Empty:
109
                if self.active_tasks == 0: break # program done
110
                else: continue
111
            
112
            def handle_result(*args, **kw_args):
113
                self.active_tasks -= 1
114
                if call.callback != None: call.callback(*args, **kw_args)
115
            
116
            self.active_tasks += 1
117
            self.pool.apply_async(call.func, self.post_unpickle(call.args),
118
                self.post_unpickle(call.kw_args), handle_result)
109 119
    
110 120
    class Result:
111 121
        def get(timeout=None): raise NotImplementedError()
lib/dicts.py
1 1
# Dictionaries
2 2

  
3
def id_dict(objects=[]):
4
    '''Makes a dict of objects by id() value'''
5
    return dict(((id(v), v) for v in objects))
3
import itertools
6 4

  
5
class IdDict(dict):
6
    '''A dict that stores objects by id()'''
7
    
8
    def add(self, *values):
9
        for value in values: self[id(value)] = value
10
        return self
11
    
12
    def add_vars(self, vars_): return self.add(*vars_.values())
13

  
7 14
class MergeDict:
8 15
    '''A dict that checks each of several dicts'''
16
    
9 17
    def __init__(self, *dicts): self.dicts = dicts
10 18
    
11 19
    def __getitem__(self, key):

Also available in: Unified diff