diff --git a/CHANGELOG.md b/CHANGELOG.md index 9eda5dd04..a1141a971 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ * Bugfix - Fix Python 3.10 compatibility (#983) PR #972 * Bugfix - Allow renaming non-conforming attributes in proj (#982) PR #972 * Add - Expose proxy feature for S3 external stores (#961) PR #962 +* Add - implement multiprocessing in populate (#695) PR #704, #969 * Bugfix - Dependencies not properly loaded on populate. (#902) PR #919 * Bugfix - Replace use of numpy aliases of built-in types with built-in type. (#938) PR #939 * Bugfix - `ExternalTable.delete` should not remove row on error (#953) PR #956 diff --git a/datajoint/autopopulate.py b/datajoint/autopopulate.py index 7802f8288..3aa9e78a8 100644 --- a/datajoint/autopopulate.py +++ b/datajoint/autopopulate.py @@ -8,12 +8,37 @@ from .expression import QueryExpression, AndList from .errors import DataJointError, LostConnectionError import signal +import multiprocessing as mp # noinspection PyExceptionInherit,PyCallingNonCallable logger = logging.getLogger(__name__) +# --- helper functions for multiprocessing -- + +def _initialize_populate(table, jobs, populate_kwargs): + """ + Initialize the process for mulitprocessing. + Saves the unpickled copy of the table to the current process and reconnects. + """ + process = mp.current_process() + process.table = table + process.jobs = jobs + process.populate_kwargs = populate_kwargs + table.connection.connect() # reconnect + + +def _call_populate1(key): + """ + Call current process' table._populate1() + :key - a dict specifying job to compute + :return: key, error if error, otherwise None + """ + process = mp.current_process() + return process.table._populate1(key, process.jobs, **process.populate_kwargs) + + class AutoPopulate: """ AutoPopulate is a mixin class that adds the method populate() to a Relation class. @@ -28,8 +53,9 @@ def key_source(self): """ :return: the query expression that yields primary key values to be passed, sequentially, to the ``make`` method when populate() is called. - The default value is the join of the parent relations. - Users may override to change the granularity or the scope of populate() calls. + The default value is the join of the parent tables references from the primary key. + Subclasses may override they key_source to change the scope or the granularity + of the make calls. """ def _rename_attributes(table, props): return (table.proj( @@ -96,18 +122,20 @@ def _jobs_to_do(self, restrictions): def populate(self, *restrictions, suppress_errors=False, return_exception_objects=False, reserve_jobs=False, order="original", limit=None, max_calls=None, - display_progress=False): + display_progress=False, processes=1): """ - rel.populate() calls rel.make(key) for every primary key in self.key_source - for which there is not already a tuple in rel. - :param restrictions: a list of restrictions each restrict (rel.key_source - target.proj()) + table.populate() calls table.make(key) for every primary key in self.key_source + for which there is not already a tuple in table. + :param restrictions: a list of restrictions each restrict (table.key_source - target.proj()) :param suppress_errors: if True, do not terminate execution. :param return_exception_objects: return error objects instead of just error messages - :param reserve_jobs: if true, reserves job to populate in asynchronous fashion + :param reserve_jobs: if True, reserve jobs to populate in asynchronous fashion :param order: "original"|"reverse"|"random" - the order of execution + :param limit: if not None, check at most this many keys + :param max_calls: if not None, populate at most this many keys :param display_progress: if True, report progress_bar - :param limit: if not None, checks at most that many keys - :param max_calls: if not None, populates at max that many keys + :param processes: number of processes to use. When set to a large number, then + uses as many as CPU cores """ if self.connection.in_transaction: raise DataJointError('Populate cannot be called during a transaction.') @@ -115,10 +143,9 @@ def populate(self, *restrictions, suppress_errors=False, return_exception_object valid_order = ['original', 'reverse', 'random'] if order not in valid_order: raise DataJointError('The order argument must be one of %s' % str(valid_order)) - error_list = [] if suppress_errors else None jobs = self.connection.schemas[self.target.database].jobs if reserve_jobs else None - # define and setup signal handler for SIGTERM + # define and set up signal handler for SIGTERM: if reserve_jobs: def handler(signum, frame): logger.info('Populate terminated by SIGTERM') @@ -131,60 +158,99 @@ def handler(signum, frame): elif order == "random": random.shuffle(keys) - call_count = 0 logger.info('Found %d keys to populate' % len(keys)) - make = self._make_tuples if hasattr(self, '_make_tuples') else self.make + keys = keys[:max_calls] + nkeys = len(keys) - for key in (tqdm(keys, desc=self.__class__.__name__) if display_progress else keys): - if max_calls is not None and call_count >= max_calls: - break - if not reserve_jobs or jobs.reserve(self.target.table_name, self._job_key(key)): - self.connection.start_transaction() - if key in self.target: # already populated - self.connection.cancel_transaction() - if reserve_jobs: - jobs.complete(self.target.table_name, self._job_key(key)) + if processes > 1: + processes = min(processes, nkeys, mp.cpu_count()) + + error_list = [] + populate_kwargs = dict( + suppress_errors=suppress_errors, + return_exception_objects=return_exception_objects) + + if processes == 1: + for key in tqdm(keys, desc=self.__class__.__name__) if display_progress else keys: + error = self._populate1(key, jobs, **populate_kwargs) + if error is not None: + error_list.append(error) + else: + # spawn multiple processes + self.connection.close() # disconnect parent process from MySQL server + del self.connection._conn.ctx # SSLContext is not pickleable + with mp.Pool(processes, _initialize_populate, (self, populate_kwargs)) as pool: + if display_progress: + with tqdm(desc="Processes: ", total=nkeys) as pbar: + for error in pool.imap(_call_populate1, keys, chunksize=1): + if error is not None: + error_list.append(error) + pbar.update() else: - logger.info('Populating: ' + str(key)) - call_count += 1 - self.__class__._allow_insert = True - try: - make(dict(key)) - except (KeyboardInterrupt, SystemExit, Exception) as error: - try: - self.connection.cancel_transaction() - except LostConnectionError: - pass - error_message = '{exception}{msg}'.format( - exception=error.__class__.__name__, - msg=': ' + str(error) if str(error) else '') - if reserve_jobs: - # show error name and error message (if any) - jobs.error( - self.target.table_name, self._job_key(key), - error_message=error_message, error_stack=traceback.format_exc()) - if not suppress_errors or isinstance(error, SystemExit): - raise - else: - logger.error(error) - error_list.append((key, error if return_exception_objects else error_message)) - else: - self.connection.commit_transaction() - if reserve_jobs: - jobs.complete(self.target.table_name, self._job_key(key)) - finally: - self.__class__._allow_insert = False + for error in pool.imap(_call_populate1, keys): + if error is not None: + error_list.append(error) + self.connection.connect() # reconnect parent process to MySQL server - # place back the original signal handler + # restore original signal handler: if reserve_jobs: signal.signal(signal.SIGTERM, old_handler) - return error_list + + if suppress_errors: + return error_list + + def _populate1(self, key, jobs, suppress_errors, return_exception_objects): + """ + populates table for one source key, calling self.make inside a transaction. + :param jobs: the jobs table or None if not reserve_jobs + :param key: dict specifying job to populate + :param suppress_errors: bool if errors should be suppressed and returned + :param return_exception_objects: if True, errors must be returned as objects + :return: (key, error) when suppress_errors=True, otherwise None + """ + make = self._make_tuples if hasattr(self, '_make_tuples') else self.make + + if jobs is None or jobs.reserve(self.target.table_name, self._job_key(key)): + self.connection.start_transaction() + if key in self.target: # already populated + self.connection.cancel_transaction() + if jobs is not None: + jobs.complete(self.target.table_name, self._job_key(key)) + else: + logger.info('Populating: ' + str(key)) + self.__class__._allow_insert = True + try: + make(dict(key)) + except (KeyboardInterrupt, SystemExit, Exception) as error: + try: + self.connection.cancel_transaction() + except LostConnectionError: + pass + error_message = '{exception}{msg}'.format( + exception=error.__class__.__name__, + msg=': ' + str(error) if str(error) else '') + if jobs is not None: + # show error name and error message (if any) + jobs.error( + self.target.table_name, self._job_key(key), + error_message=error_message, error_stack=traceback.format_exc()) + if not suppress_errors or isinstance(error, SystemExit): + raise + else: + logger.error(error) + return key, error if return_exception_objects else error_message + else: + self.connection.commit_transaction() + if jobs is not None: + jobs.complete(self.target.table_name, self._job_key(key)) + finally: + self.__class__._allow_insert = False def progress(self, *restrictions, display=True): """ - report progress of populating the table - :return: remaining, total -- tuples to be populated + Report the progress of populating the table. + :return: (remaining, total) -- numbers of tuples to be populated """ todo = self._jobs_to_do(restrictions) total = len(todo) diff --git a/datajoint/blob.py b/datajoint/blob.py index be348ef62..d3837cb6a 100644 --- a/datajoint/blob.py +++ b/datajoint/blob.py @@ -166,6 +166,8 @@ def pack_blob(self, obj): return self.pack_array(np.array(obj)) if isinstance(obj, (bool, np.bool_)): return self.pack_array(np.array(obj)) + if isinstance(obj, (float, int, complex)): + return self.pack_array(np.array(obj)) if isinstance(obj, (datetime.datetime, datetime.date, datetime.time)): return self.pack_datetime(obj) if isinstance(obj, Decimal): diff --git a/datajoint/connection.py b/datajoint/connection.py index cc4961adb..3daac4bac 100644 --- a/datajoint/connection.py +++ b/datajoint/connection.py @@ -278,7 +278,7 @@ def query(self, query, args=(), *, as_dict=False, suppress_warnings=True, reconn # check cache first: use_query_cache = bool(self._query_cache) if use_query_cache and not re.match(r"\s*(SELECT|SHOW)", query): - raise errors.DataJointError("Only SELECT query are allowed when query caching is on.") + raise errors.DataJointError("Only SELECT queries are allowed when query caching is on.") if use_query_cache: if not config['query_cache']: raise errors.DataJointError("Provide filepath dj.config['query_cache'] when using query caching.") diff --git a/datajoint/expression.py b/datajoint/expression.py index d2277c7c7..85dce57e6 100644 --- a/datajoint/expression.py +++ b/datajoint/expression.py @@ -37,7 +37,7 @@ class QueryExpression: """ _restriction = None _restriction_attributes = None - _left = [] # True for left joins, False for inner joins + _left = [] # list of booleans True for left joins, False for inner joins _original_heading = None # heading before projections # subclasses or instantiators must provide values @@ -263,7 +263,7 @@ def join(self, other, semantic_check=True, left=False): if semantic_check: assert_join_compatibility(self, other) join_attributes = set(n for n in self.heading.names if n in other.heading.names) - # needs subquery if FROM class has common attributes with the other's FROM clause + # needs subquery if self's FROM clause has common attributes with other's FROM clause need_subquery1 = need_subquery2 = bool( (set(self.original_heading.names) & set(other.original_heading.names)) - join_attributes) @@ -306,7 +306,7 @@ def proj(self, *attributes, **named_attributes): self.proj(...) or self.proj(Ellipsis) -- include all attributes (return self) self.proj() -- include only primary key self.proj('attr1', 'attr2') -- include primary key and attributes attr1 and attr2 - self.proj(..., '-attr1', '-attr2') -- include attributes except attr1 and attr2 + self.proj(..., '-attr1', '-attr2') -- include all attributes except attr1 and attr2 self.proj(name1='attr1') -- include primary key and 'attr1' renamed as name1 self.proj('attr1', dup='(attr1)') -- include primary key and attribute attr1 twice, with the duplicate 'dup' self.proj(k='abs(attr1)') adds the new attribute k with the value computed as an expression (SQL syntax) diff --git a/docs-parts/intro/Releases_lang1.rst b/docs-parts/intro/Releases_lang1.rst index d4ea88ae9..c87234aab 100644 --- a/docs-parts/intro/Releases_lang1.rst +++ b/docs-parts/intro/Releases_lang1.rst @@ -4,6 +4,7 @@ * Bugfix - Fix Python 3.10 compatibility (#983) PR #972 * Bugfix - Allow renaming non-conforming attributes in proj (#982) PR #972 * Add - Expose proxy feature for S3 external stores (#961) PR #962 +* Add - implement multiprocessing in populate (#695) PR #704, #969 * Bugfix - Dependencies not properly loaded on populate. (#902) PR #919 * Bugfix - Replace use of numpy aliases of built-in types with built-in type. (#938) PR #939 * Bugfix - `ExternalTable.delete` should not remove row on error (#953) PR #956 diff --git a/tests/test_autopopulate.py b/tests/test_autopopulate.py index 081787670..1875a6743 100644 --- a/tests/test_autopopulate.py +++ b/tests/test_autopopulate.py @@ -8,7 +8,6 @@ class TestPopulate: """ Test base relations: insert, delete """ - def setUp(self): self.user = schema.User() self.subject = schema.Subject() @@ -53,7 +52,7 @@ def test_populate(self): def test_allow_direct_insert(self): assert_true(self.subject, 'root tables are empty') - key = self.subject.fetch('KEY')[0] + key = self.subject.fetch('KEY', limit=1)[0] key['experiment_id'] = 1000 key['experiment_date'] = '2018-10-30' self.experiment.insert1(key, allow_direct_insert=True) diff --git a/tests/test_blob.py b/tests/test_blob.py index 225fb775c..61e60a8d0 100644 --- a/tests/test_blob.py +++ b/tests/test_blob.py @@ -23,6 +23,9 @@ def test_pack(): x = np.random.randn(10) assert_array_equal(x, unpack(pack(x)), "Arrays do not match!") + x = 7j + assert_equal(x, unpack(pack(x)), "Complex scalar does not match") + x = np.float32(np.random.randn(3, 4, 5)) assert_array_equal(x, unpack(pack(x)), "Arrays do not match!")