Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
52 commits
Select commit Hold shift + click to select a range
9822fc4
Add multiprocessing to AutoPopulate
mspacek Nov 18, 2019
6074900
Update docs
mspacek Nov 18, 2019
a7e4c2e
Rename max_processes -> multiprocess, accept bool or int
mspacek Nov 18, 2019
24de484
Fix reserved jobs
mspacek Nov 18, 2019
e2f2b18
Merge pull request #707 from datajoint/master
dimitri-yatsenko Nov 20, 2019
28aae7a
Use is instead of ==
mspacek Nov 21, 2019
236e62e
Replace assertion with DataJointError
mspacek Nov 21, 2019
9266668
Remove extra blank line
mspacek Nov 21, 2019
4f30328
fix #700
dimitri-yatsenko Nov 17, 2019
8d7d7d3
fix #700 differently - bypass the restriction against python-native d…
dimitri-yatsenko Nov 17, 2019
3e3d688
fix typo from previous commit
dimitri-yatsenko Nov 18, 2019
ad02fe1
tests/{schema,test_jobs}.py: add tests/test_jobs.py:test_suppress_dj_…
Nov 19, 2019
53d073f
cleanup
dimitri-yatsenko Nov 19, 2019
df27ada
minor syntax improvement
dimitri-yatsenko Nov 19, 2019
8f46e05
minor syntax
dimitri-yatsenko Nov 19, 2019
6ed83c6
fix #675
dimitri-yatsenko Nov 19, 2019
91556d5
Fix #698
dimitri-yatsenko Nov 19, 2019
2c8aedf
fix #699 -- add table definition to doc string
dimitri-yatsenko Nov 19, 2019
bc69123
Table doc strings now display reverse-engineered table declarations
dimitri-yatsenko Nov 19, 2019
1034797
update error message for un-upgraded external stores.
dimitri-yatsenko Nov 19, 2019
d934b7e
improve table definition in the doc string
dimitri-yatsenko Nov 19, 2019
2b8588c
minor improvement in display of table doc strings
dimitri-yatsenko Nov 19, 2019
c7b7989
replace .describe() with .definition to augment the table docstring
dimitri-yatsenko Nov 20, 2019
ca775da
improve unit test for definition in docstring
dimitri-yatsenko Nov 21, 2019
b646061
minor
dimitri-yatsenko Nov 21, 2019
60c952b
update CHANGELOG for version 0.12.3
dimitri-yatsenko Nov 22, 2019
95c2249
update docs for release 0.12.3
dimitri-yatsenko Nov 22, 2019
b3ee8d9
blob now accepts native complex scalars
dimitri-yatsenko Nov 22, 2019
3e6b174
Merge pull request #704 from mspacek/mp
dimitri-yatsenko Dec 16, 2019
fa59df2
Merge branch 'mp' of https://github.com/datajoint/datajoint-python in…
dimitri-yatsenko Jan 24, 2020
bb9351c
add OVERVIEW.md
dimitri-yatsenko Feb 13, 2021
2d49377
minor edits in OVERVIEW.md
dimitri-yatsenko Feb 13, 2021
80b9de8
Merge branch 'cascade-delete' into r013-docs
dimitri-yatsenko Mar 12, 2021
52f1f70
minor error message wording
dimitri-yatsenko Mar 13, 2021
e45e922
comment and docstring corrections
dimitri-yatsenko Mar 13, 2021
0409ffd
Merge branch 'datajoint:master' into r013-docs
dimitri-yatsenko Sep 25, 2021
f6a5072
Merge branch 'master' into mp
dimitri-yatsenko Sep 25, 2021
cc144a6
Merge branch 'master' into mp
dimitri-yatsenko Sep 26, 2021
0243baf
Merge branch 'r013-docs' of https://github.com/dimitri-yatsenko/dataj…
dimitri-yatsenko Sep 26, 2021
fbdef21
minor
dimitri-yatsenko Sep 26, 2021
6fea058
add OVERVIEW.md
dimitri-yatsenko Sep 26, 2021
c641029
remove OVERVIEW.md
dimitri-yatsenko Sep 26, 2021
b52a130
doc string improvements in autopopulate
dimitri-yatsenko Sep 26, 2021
0079a7e
minor cleanup in autopopulate
dimitri-yatsenko Sep 27, 2021
85520ed
minor PEP8
dimitri-yatsenko Oct 7, 2021
88c634b
Merge branch 'master' of https://github.com/datajoint/datajoint-pytho…
dimitri-yatsenko Oct 8, 2021
54a61b8
Merge branch 'master' into mp
dimitri-yatsenko Jan 19, 2022
ae59587
Merge branch 'master' into mp
dimitri-yatsenko Jan 19, 2022
ba3039c
Merge branch 'master' into mp
dimitri-yatsenko Jan 19, 2022
a661d19
Merge branch 'master' into mp
dimitri-yatsenko Jan 20, 2022
ab691d3
update CHANGELOG to include multiprocessing
dimitri-yatsenko Jan 20, 2022
9ba4e86
whitespace
dimitri-yatsenko Jan 20, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
* Bugfix - Fix Python 3.10 compatibility (#983) PR #972
* Bugfix - Allow renaming non-conforming attributes in proj (#982) PR #972
* Add - Expose proxy feature for S3 external stores (#961) PR #962
* Add - implement multiprocessing in populate (#695) PR #704, #969
* Bugfix - Dependencies not properly loaded on populate. (#902) PR #919
* Bugfix - Replace use of numpy aliases of built-in types with built-in type. (#938) PR #939
* Bugfix - `ExternalTable.delete` should not remove row on error (#953) PR #956
Expand Down
176 changes: 121 additions & 55 deletions datajoint/autopopulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,37 @@
from .expression import QueryExpression, AndList
from .errors import DataJointError, LostConnectionError
import signal
import multiprocessing as mp

# noinspection PyExceptionInherit,PyCallingNonCallable

logger = logging.getLogger(__name__)


# --- helper functions for multiprocessing --

def _initialize_populate(table, jobs, populate_kwargs):
"""
Initialize the process for mulitprocessing.
Saves the unpickled copy of the table to the current process and reconnects.
"""
process = mp.current_process()
process.table = table
process.jobs = jobs
process.populate_kwargs = populate_kwargs
table.connection.connect() # reconnect


def _call_populate1(key):
"""
Call current process' table._populate1()
:key - a dict specifying job to compute
:return: key, error if error, otherwise None
"""
process = mp.current_process()
return process.table._populate1(key, process.jobs, **process.populate_kwargs)


class AutoPopulate:
"""
AutoPopulate is a mixin class that adds the method populate() to a Relation class.
Expand All @@ -28,8 +53,9 @@ def key_source(self):
"""
:return: the query expression that yields primary key values to be passed,
sequentially, to the ``make`` method when populate() is called.
The default value is the join of the parent relations.
Users may override to change the granularity or the scope of populate() calls.
The default value is the join of the parent tables references from the primary key.
Subclasses may override they key_source to change the scope or the granularity
of the make calls.
"""
def _rename_attributes(table, props):
return (table.proj(
Expand Down Expand Up @@ -96,29 +122,30 @@ def _jobs_to_do(self, restrictions):

def populate(self, *restrictions, suppress_errors=False, return_exception_objects=False,
reserve_jobs=False, order="original", limit=None, max_calls=None,
display_progress=False):
display_progress=False, processes=1):
"""
rel.populate() calls rel.make(key) for every primary key in self.key_source
for which there is not already a tuple in rel.
:param restrictions: a list of restrictions each restrict (rel.key_source - target.proj())
table.populate() calls table.make(key) for every primary key in self.key_source
for which there is not already a tuple in table.
:param restrictions: a list of restrictions each restrict (table.key_source - target.proj())
:param suppress_errors: if True, do not terminate execution.
:param return_exception_objects: return error objects instead of just error messages
:param reserve_jobs: if true, reserves job to populate in asynchronous fashion
:param reserve_jobs: if True, reserve jobs to populate in asynchronous fashion
:param order: "original"|"reverse"|"random" - the order of execution
:param limit: if not None, check at most this many keys
:param max_calls: if not None, populate at most this many keys
:param display_progress: if True, report progress_bar
:param limit: if not None, checks at most that many keys
:param max_calls: if not None, populates at max that many keys
:param processes: number of processes to use. When set to a large number, then
uses as many as CPU cores
"""
if self.connection.in_transaction:
raise DataJointError('Populate cannot be called during a transaction.')

valid_order = ['original', 'reverse', 'random']
if order not in valid_order:
raise DataJointError('The order argument must be one of %s' % str(valid_order))
error_list = [] if suppress_errors else None
jobs = self.connection.schemas[self.target.database].jobs if reserve_jobs else None

# define and setup signal handler for SIGTERM
# define and set up signal handler for SIGTERM:
if reserve_jobs:
def handler(signum, frame):
logger.info('Populate terminated by SIGTERM')
Expand All @@ -131,60 +158,99 @@ def handler(signum, frame):
elif order == "random":
random.shuffle(keys)

call_count = 0
logger.info('Found %d keys to populate' % len(keys))

make = self._make_tuples if hasattr(self, '_make_tuples') else self.make
keys = keys[:max_calls]
nkeys = len(keys)

for key in (tqdm(keys, desc=self.__class__.__name__) if display_progress else keys):
if max_calls is not None and call_count >= max_calls:
break
if not reserve_jobs or jobs.reserve(self.target.table_name, self._job_key(key)):
self.connection.start_transaction()
if key in self.target: # already populated
self.connection.cancel_transaction()
if reserve_jobs:
jobs.complete(self.target.table_name, self._job_key(key))
if processes > 1:
processes = min(processes, nkeys, mp.cpu_count())

error_list = []
populate_kwargs = dict(
suppress_errors=suppress_errors,
return_exception_objects=return_exception_objects)

if processes == 1:
for key in tqdm(keys, desc=self.__class__.__name__) if display_progress else keys:
error = self._populate1(key, jobs, **populate_kwargs)
if error is not None:
error_list.append(error)
else:
# spawn multiple processes
self.connection.close() # disconnect parent process from MySQL server
del self.connection._conn.ctx # SSLContext is not pickleable
with mp.Pool(processes, _initialize_populate, (self, populate_kwargs)) as pool:
if display_progress:
with tqdm(desc="Processes: ", total=nkeys) as pbar:
for error in pool.imap(_call_populate1, keys, chunksize=1):
if error is not None:
error_list.append(error)
pbar.update()
else:
logger.info('Populating: ' + str(key))
call_count += 1
self.__class__._allow_insert = True
try:
make(dict(key))
except (KeyboardInterrupt, SystemExit, Exception) as error:
try:
self.connection.cancel_transaction()
except LostConnectionError:
pass
error_message = '{exception}{msg}'.format(
exception=error.__class__.__name__,
msg=': ' + str(error) if str(error) else '')
if reserve_jobs:
# show error name and error message (if any)
jobs.error(
self.target.table_name, self._job_key(key),
error_message=error_message, error_stack=traceback.format_exc())
if not suppress_errors or isinstance(error, SystemExit):
raise
else:
logger.error(error)
error_list.append((key, error if return_exception_objects else error_message))
else:
self.connection.commit_transaction()
if reserve_jobs:
jobs.complete(self.target.table_name, self._job_key(key))
finally:
self.__class__._allow_insert = False
for error in pool.imap(_call_populate1, keys):
if error is not None:
error_list.append(error)
self.connection.connect() # reconnect parent process to MySQL server

# place back the original signal handler
# restore original signal handler:
if reserve_jobs:
signal.signal(signal.SIGTERM, old_handler)
return error_list

if suppress_errors:
return error_list

def _populate1(self, key, jobs, suppress_errors, return_exception_objects):
"""
populates table for one source key, calling self.make inside a transaction.
:param jobs: the jobs table or None if not reserve_jobs
:param key: dict specifying job to populate
:param suppress_errors: bool if errors should be suppressed and returned
:param return_exception_objects: if True, errors must be returned as objects
:return: (key, error) when suppress_errors=True, otherwise None
"""
make = self._make_tuples if hasattr(self, '_make_tuples') else self.make

if jobs is None or jobs.reserve(self.target.table_name, self._job_key(key)):
self.connection.start_transaction()
if key in self.target: # already populated
self.connection.cancel_transaction()
if jobs is not None:
jobs.complete(self.target.table_name, self._job_key(key))
else:
logger.info('Populating: ' + str(key))
self.__class__._allow_insert = True
try:
make(dict(key))
except (KeyboardInterrupt, SystemExit, Exception) as error:
try:
self.connection.cancel_transaction()
except LostConnectionError:
pass
error_message = '{exception}{msg}'.format(
exception=error.__class__.__name__,
msg=': ' + str(error) if str(error) else '')
if jobs is not None:
# show error name and error message (if any)
jobs.error(
self.target.table_name, self._job_key(key),
error_message=error_message, error_stack=traceback.format_exc())
if not suppress_errors or isinstance(error, SystemExit):
raise
else:
logger.error(error)
return key, error if return_exception_objects else error_message
else:
self.connection.commit_transaction()
if jobs is not None:
jobs.complete(self.target.table_name, self._job_key(key))
finally:
self.__class__._allow_insert = False

def progress(self, *restrictions, display=True):
"""
report progress of populating the table
:return: remaining, total -- tuples to be populated
Report the progress of populating the table.
:return: (remaining, total) -- numbers of tuples to be populated
"""
todo = self._jobs_to_do(restrictions)
total = len(todo)
Expand Down
2 changes: 2 additions & 0 deletions datajoint/blob.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,8 @@ def pack_blob(self, obj):
return self.pack_array(np.array(obj))
if isinstance(obj, (bool, np.bool_)):
return self.pack_array(np.array(obj))
if isinstance(obj, (float, int, complex)):
return self.pack_array(np.array(obj))
if isinstance(obj, (datetime.datetime, datetime.date, datetime.time)):
return self.pack_datetime(obj)
if isinstance(obj, Decimal):
Expand Down
2 changes: 1 addition & 1 deletion datajoint/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ def query(self, query, args=(), *, as_dict=False, suppress_warnings=True, reconn
# check cache first:
use_query_cache = bool(self._query_cache)
if use_query_cache and not re.match(r"\s*(SELECT|SHOW)", query):
raise errors.DataJointError("Only SELECT query are allowed when query caching is on.")
raise errors.DataJointError("Only SELECT queries are allowed when query caching is on.")
if use_query_cache:
if not config['query_cache']:
raise errors.DataJointError("Provide filepath dj.config['query_cache'] when using query caching.")
Expand Down
6 changes: 3 additions & 3 deletions datajoint/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class QueryExpression:
"""
_restriction = None
_restriction_attributes = None
_left = [] # True for left joins, False for inner joins
_left = [] # list of booleans True for left joins, False for inner joins
_original_heading = None # heading before projections

# subclasses or instantiators must provide values
Expand Down Expand Up @@ -263,7 +263,7 @@ def join(self, other, semantic_check=True, left=False):
if semantic_check:
assert_join_compatibility(self, other)
join_attributes = set(n for n in self.heading.names if n in other.heading.names)
# needs subquery if FROM class has common attributes with the other's FROM clause
# needs subquery if self's FROM clause has common attributes with other's FROM clause
need_subquery1 = need_subquery2 = bool(
(set(self.original_heading.names) & set(other.original_heading.names))
- join_attributes)
Expand Down Expand Up @@ -306,7 +306,7 @@ def proj(self, *attributes, **named_attributes):
self.proj(...) or self.proj(Ellipsis) -- include all attributes (return self)
self.proj() -- include only primary key
self.proj('attr1', 'attr2') -- include primary key and attributes attr1 and attr2
self.proj(..., '-attr1', '-attr2') -- include attributes except attr1 and attr2
self.proj(..., '-attr1', '-attr2') -- include all attributes except attr1 and attr2
self.proj(name1='attr1') -- include primary key and 'attr1' renamed as name1
self.proj('attr1', dup='(attr1)') -- include primary key and attribute attr1 twice, with the duplicate 'dup'
self.proj(k='abs(attr1)') adds the new attribute k with the value computed as an expression (SQL syntax)
Expand Down
1 change: 1 addition & 0 deletions docs-parts/intro/Releases_lang1.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
* Bugfix - Fix Python 3.10 compatibility (#983) PR #972
* Bugfix - Allow renaming non-conforming attributes in proj (#982) PR #972
* Add - Expose proxy feature for S3 external stores (#961) PR #962
* Add - implement multiprocessing in populate (#695) PR #704, #969
* Bugfix - Dependencies not properly loaded on populate. (#902) PR #919
* Bugfix - Replace use of numpy aliases of built-in types with built-in type. (#938) PR #939
* Bugfix - `ExternalTable.delete` should not remove row on error (#953) PR #956
Expand Down
3 changes: 1 addition & 2 deletions tests/test_autopopulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ class TestPopulate:
"""
Test base relations: insert, delete
"""

def setUp(self):
self.user = schema.User()
self.subject = schema.Subject()
Expand Down Expand Up @@ -53,7 +52,7 @@ def test_populate(self):

def test_allow_direct_insert(self):
assert_true(self.subject, 'root tables are empty')
key = self.subject.fetch('KEY')[0]
key = self.subject.fetch('KEY', limit=1)[0]
key['experiment_id'] = 1000
key['experiment_date'] = '2018-10-30'
self.experiment.insert1(key, allow_direct_insert=True)
Expand Down
3 changes: 3 additions & 0 deletions tests/test_blob.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ def test_pack():
x = np.random.randn(10)
assert_array_equal(x, unpack(pack(x)), "Arrays do not match!")

x = 7j
assert_equal(x, unpack(pack(x)), "Complex scalar does not match")

x = np.float32(np.random.randn(3, 4, 5))
assert_array_equal(x, unpack(pack(x)), "Arrays do not match!")

Expand Down