# Parallel processing

import cPickle
import itertools
import Queue
import rand
import types
import warnings

import collection
import dicts
import exc
from Runnable import Runnable

def try_pickle(value):
    try: cPickle.dumps(value)
    except Exception, e:
        exc.add_msg(e, 'Tried to pickle: '+repr(value))
        raise

def prepickle(value, vars_id_dict_):
    '''Recursively prepare value for pickling: any object registered in
    vars_id_dict_ is replaced by its id() (to be restored by post_unpickle),
    and every other leaf is checked for picklability.'''
    def swap_or_check(node, is_leaf):
        node_id = id(node)
        if node_id in vars_id_dict_:
            return node_id
        # Try pickling the value. If it fails, we'll get a full traceback here,
        # which is not provided with pickling errors in multiprocessing's Pool.
        if is_leaf:
            try_pickle(node)
        return node
    return collection.rmap(swap_or_check, value)

def post_unpickle(value, vars_id_dict_):
    '''Recursively undo prepickle(): any integer that matches the id() of a
    registered shared object is replaced by that object.'''
    def filter_(value, is_leaf):
        # id() may return either an int or a long on Python 2 (e.g. on 64-bit
        # Windows, where a pointer doesn't fit in a C long), so accept both;
        # the previous type(value) == int check missed long ids and silently
        # left the placeholder integer in place.
        if isinstance(value, (int, long)):
            value = vars_id_dict_.get(value, value)
            # get() returns the value itself if it isn't a known id()
        return value
    return collection.rmap(filter_, value)

class SyncPool:
    '''A dummy synchronous Pool to use if multiprocessing is not available.
    
    Mimics the subset of the multiprocessing.Pool interface used by this
    module, but runs each task immediately in the calling thread.'''
    def __init__(self, processes=None): pass
    
    class Result:
        '''Mimics multiprocessing's AsyncResult for an already-computed value.
        
        Every method previously lacked the `self` parameter, so e.g.
        res.get() bound the instance to `timeout` and then raised NameError
        on `self.value`.'''
        def __init__(self, value): self.value = value
        
        def get(self, timeout=None): return self.value
        
        def wait(self, timeout=None): pass
        
        def ready(self): return True
        
        def successful(self): return True # TODO: False if raised exception
    
    def apply_async(self, func, args=(), kw_args={}, callback=None):
        '''Run func(*args, **kw_args) synchronously, invoke callback with the
        result, and return a Result already holding the value.'''
        if callback is None: callback = lambda v: None
        
        value = func(*args, **kw_args)
        callback(value)
        return self.Result(value)

class MultiProducerPool:
    '''A multi-producer pool. You must call pool.main_loop() in the thread that
    created this to process new tasks.'''
    
    def __init__(self, processes=None, locals_={}, *shared):
        '''
        @param processes If 0, uses SyncPool
        @post The # processes actually used is made available in self.process_ct
        '''
        try:
            if processes == 0: raise ImportError('turned off')
            import multiprocessing
            import multiprocessing.pool
        except ImportError, e:
            warnings.warn(UserWarning('Not using parallel processing: '+str(e)))
            processes = 1
            Pool_ = SyncPool
            Queue_ = Queue.Queue
        else:
            if processes == None: processes = multiprocessing.cpu_count()
            Pool_ = multiprocessing.pool.Pool
            Queue_ = multiprocessing.Queue
        
        self.process_ct = processes
        self.pool = Pool_(processes)
        self.queue = Queue_()
        self.active_tasks = 0
        
        # Store a reference to the manager in self, because it will otherwise be
        # shutdown right away when it goes out of scope
        #self.manager = processing.Manager()
        #self.shared_rw = self.manager.Namespace()
        
        # Values that may be pickled by id()
        self.vars_id_dict = dicts.IdDict()
        self.share(self, *shared).share_vars(locals_).share_vars(globals())
    
    def share(self, *values):
        '''Call this on all values that should be shared writably between all
        processes (and be pickled by id())'''
        self.vars_id_dict.add(*values)
        return self
    
    def share_vars(self, vars_):
        '''Call this on all vars that should be pickled by id().
        Usage: self.share_vars(locals())
        @param vars_ {var_name: value}
        '''
        self.vars_id_dict.add_vars(vars_)
        return self
    
    def main_loop(self):
        '''Prime the task queue with at least one task before calling this''' 
        while True:
            try: call = self.queue.get(timeout=0.1) # sec
            except Queue.Empty:
                if self.active_tasks == 0: break # program done
                else: continue
            
            def handle_result(*args, **kw_args):
                self.active_tasks -= 1
                if call.callback != None: call.callback(*args, **kw_args)
            
            self.active_tasks += 1
            self.pool.apply_async(call.func, self.post_unpickle(call.args),
                self.post_unpickle(call.kw_args), handle_result)
    
    class Result:
        def get(timeout=None): raise NotImplementedError()
        
        def wait(timeout=None): raise NotImplementedError()
        
        def ready(): raise NotImplementedError()
        
        def successful(): raise NotImplementedError()
    
    def apply_async(self, func, args=(), kw_args={}, callback=None):
        assert callback == None, 'Callbacks not supported'
        
        call = Runnable(func, *self.prepickle(args), **self.prepickle(kw_args))
        call.callback = callback # store this inside the Runnable
        
        self.queue.put_nowait(call)
        return self.Result()
    
    def prepickle(self, value): return prepickle(value, self.vars_id_dict)
    
    def post_unpickle(self, value):
        return post_unpickle(value, self.vars_id_dict)
