diff --git a/docs/src/compute/key-source.md b/docs/src/compute/key-source.md index 76796ec0c..c9b5d2ce7 100644 --- a/docs/src/compute/key-source.md +++ b/docs/src/compute/key-source.md @@ -45,7 +45,7 @@ definition = """ -> Recording --- sample_rate : float -eeg_data : longblob +eeg_data : """ key_source = Recording & 'recording_type = "EEG"' ``` diff --git a/docs/src/compute/make.md b/docs/src/compute/make.md index 1b5569b65..390be3b7b 100644 --- a/docs/src/compute/make.md +++ b/docs/src/compute/make.md @@ -152,7 +152,7 @@ class ImageAnalysis(dj.Computed): # Complex image analysis results -> Image --- - analysis_result : longblob + analysis_result : processing_time : float """ @@ -188,7 +188,7 @@ class ImageAnalysis(dj.Computed): # Complex image analysis results -> Image --- - analysis_result : longblob + analysis_result : processing_time : float """ diff --git a/docs/src/compute/populate.md b/docs/src/compute/populate.md index 45c863f17..91db7b176 100644 --- a/docs/src/compute/populate.md +++ b/docs/src/compute/populate.md @@ -40,7 +40,7 @@ class FilteredImage(dj.Computed): # Filtered image -> Image --- - filtered_image : longblob + filtered_image : """ def make(self, key): @@ -196,7 +196,7 @@ class ImageAnalysis(dj.Computed): # Complex image analysis results -> Image --- - analysis_result : longblob + analysis_result : processing_time : float """ @@ -230,7 +230,7 @@ class ImageAnalysis(dj.Computed): # Complex image analysis results -> Image --- - analysis_result : longblob + analysis_result : processing_time : float """ diff --git a/docs/src/design/autopopulate-2.0-spec.md b/docs/src/design/autopopulate-2.0-spec.md new file mode 100644 index 000000000..2e471cc5e --- /dev/null +++ b/docs/src/design/autopopulate-2.0-spec.md @@ -0,0 +1,726 @@ +# Autopopulate 2.0 Specification + +## Overview + +This specification redesigns the DataJoint job handling system to provide better visibility, control, and scalability for distributed computing workflows. The new system replaces the schema-level `~jobs` table with per-table job tables that offer richer status tracking, proper referential integrity, and dashboard-friendly monitoring. + +## Problem Statement + +### Current Jobs Table Limitations + +The existing `~jobs` table has significant limitations: + +1. **Limited status tracking**: Only supports `reserved`, `error`, and `ignore` statuses +2. **Functions as an error log**: Cannot efficiently track pending or completed jobs +3. **Poor dashboard visibility**: No way to monitor pipeline progress without querying multiple tables +4. **Key hashing obscures data**: Primary keys are stored as hashes, making debugging difficult +5. **No referential integrity**: Jobs table is independent of computed tables; orphaned jobs can accumulate + +### Key Source Limitations + +1. **Frequent manual modifications**: Subset operations require modifying `key_source` property +2. **Local visibility only**: Custom key sources are not accessible database-wide +3. **Performance bottleneck**: Multiple workers querying `key_source` simultaneously creates contention +4. **Codebase dependency**: Requires full pipeline codebase to determine pending work + +## Proposed Solution + +### Terminology + +- **Stale job**: A pending job whose upstream records have been deleted. The job references keys that no longer exist in `key_source`. Stale jobs are automatically cleaned up by `refresh()`. +- **Orphaned job**: A reserved job from a crashed or terminated process. The worker that reserved the job is no longer running, but the job remains in `reserved` status. Orphaned jobs must be cleared manually (see below). + +### Core Design Principles + +1. **Per-table jobs**: Each computed table gets its own hidden jobs table +2. **FK-derived primary keys**: Jobs table primary key includes only attributes derived from foreign keys in the target table's primary key (not additional primary key attributes) +3. **No FK constraints on jobs**: Jobs tables omit foreign key constraints for performance; stale jobs are cleaned by `refresh()` +4. **Rich status tracking**: Extended status values for full lifecycle visibility +5. **Automatic refresh**: `populate()` automatically refreshes the jobs queue (adding new jobs, removing stale ones) + +## Architecture + +### Jobs Table Structure + +Each `dj.Imported` or `dj.Computed` table `MyTable` will have an associated hidden jobs table `~my_table__jobs` with the following structure: + +``` +# Job queue for MyTable +subject_id : int +session_id : int +... # Only FK-derived primary key attributes (NO foreign key constraints) +--- +status : enum('pending', 'reserved', 'success', 'error', 'ignore') +priority : int # Lower = more urgent (0 = highest priority, default: 5) +created_time : datetime # When job was added to queue +scheduled_time : datetime # Process on or after this time (default: now) +reserved_time : datetime # When job was reserved (null if not reserved) +completed_time : datetime # When job completed (null if not completed) +duration : float # Execution duration in seconds (null if not completed) +error_message : varchar(2047) # Truncated error message +error_stack : mediumblob # Full error traceback +user : varchar(255) # Database user who reserved/completed job +host : varchar(255) # Hostname of worker +pid : int unsigned # Process ID of worker +connection_id : bigint unsigned # MySQL connection ID +version : varchar(255) # Code version (git hash, package version, etc.) +``` + +**Important**: The jobs table primary key includes only those attributes that come through foreign keys in the target table's primary key. Additional primary key attributes (if any) are excluded. This means: +- If a target table has primary key `(-> Subject, -> Session, method)`, the jobs table has primary key `(subject_id, session_id)` only +- Multiple target rows may map to a single job entry when additional PK attributes exist +- Jobs tables have **no foreign key constraints** for performance (stale jobs handled by `refresh()`) + +### Access Pattern + +Jobs are accessed as a property of the computed table: + +```python +# Current pattern (schema-level) +schema.jobs + +# New pattern (per-table) +MyTable.jobs + +# Examples +FilteredImage.jobs # Access jobs table +FilteredImage.jobs & 'status="error"' # Query errors +FilteredImage.jobs.refresh() # Refresh job queue +``` + +### Status Values + +| Status | Description | +|--------|-------------| +| `pending` | Job is queued and ready to be processed | +| `reserved` | Job is currently being processed by a worker | +| `success` | Job completed successfully (optional, depends on settings) | +| `error` | Job failed with an error | +| `ignore` | Job should be skipped (manually set, not part of automatic transitions) | + +### Status Transitions + +```mermaid +stateDiagram-v2 + state "(none)" as none1 + state "(none)" as none2 + none1 --> pending : refresh() + none1 --> ignore : ignore() + pending --> reserved : reserve() + reserved --> none2 : complete() + reserved --> success : complete()* + reserved --> error : error() + success --> pending : refresh()* + error --> none2 : delete() + success --> none2 : delete() + ignore --> none2 : delete() +``` + +- `complete()` deletes the job entry (default when `jobs.keep_completed=False`) +- `complete()*` keeps the job as `success` (when `jobs.keep_completed=True`) +- `refresh()*` re-pends a `success` job if its key is in `key_source` but not in target + +**Transition methods:** +- `refresh()` — Adds new jobs as `pending`; also re-pends `success` jobs if key is in `key_source` but not in target +- `ignore()` — Marks a key as `ignore` (can be called on keys not yet in jobs table) +- `reserve()` — Marks a pending job as `reserved` before calling `make()` +- `complete()` — Marks reserved job as `success`, or deletes it (based on `jobs.keep_completed` setting) +- `error()` — Marks reserved job as `error` with message and stack trace +- `delete()` — Inherited from `delete_quick()`; use `(jobs & condition).delete()` pattern + +**Manual status control:** +- `ignore` is set manually via `jobs.ignore(key)` and is not part of automatic transitions +- Jobs with `status='ignore'` are skipped by `populate()` and `refresh()` +- To reset an ignored job, delete it and call `refresh()`: `jobs.ignored.delete(); jobs.refresh()` + +## API Design + +### JobsTable Class + +```python +class JobsTable(Table): + """Hidden table managing job queue for a computed table.""" + + @property + def definition(self) -> str: + """Dynamically generated based on parent table's primary key.""" + ... + + def refresh( + self, + *restrictions, + delay: float = 0, + priority: int = 5, + stale_timeout: float = None + ) -> dict: + """ + Refresh the jobs queue: add new jobs and remove stale ones. + + Operations performed: + 1. Add new jobs: (key_source & restrictions) - target - jobs → insert as 'pending' + 2. Remove stale jobs: pending jobs older than stale_timeout whose keys + are no longer in key_source (upstream records were deleted) + + Args: + restrictions: Conditions to filter key_source + delay: Seconds from now until jobs become available for processing. + Default: 0 (jobs are immediately available). + Uses database server time to avoid client clock synchronization issues. + priority: Priority for new jobs (lower = more urgent). Default: 5 + stale_timeout: Seconds after which pending jobs are checked for staleness. + Jobs older than this are removed if their key is no longer + in key_source. Default from config: jobs.stale_timeout (3600s) + + Returns: + {'added': int, 'removed': int} - counts of jobs added and stale jobs removed + """ + ... + + def reserve(self, key: dict) -> bool: + """ + Attempt to reserve a job for processing. + + Updates status to 'reserved' if currently 'pending' and scheduled_time <= now. + No locking is used; rare conflicts are resolved by the make() transaction. + + Returns: + True if reservation successful, False if job not found or not pending. + """ + ... + + def complete(self, key: dict, duration: float = None) -> None: + """ + Mark a job as successfully completed. + + Updates status to 'success', records duration and completion time. + """ + ... + + def error(self, key: dict, error_message: str, error_stack: str = None) -> None: + """ + Mark a job as failed with error details. + + Updates status to 'error', records error message and stack trace. + """ + ... + + def ignore(self, key: dict) -> None: + """ + Mark a job to be ignored (skipped during populate). + + To reset an ignored job, delete it and call refresh(). + """ + ... + + # delete() is inherited from delete_quick() - no confirmation required + # Usage: (jobs & condition).delete() or jobs.errors.delete() + + @property + def pending(self) -> QueryExpression: + """Return query for pending jobs.""" + return self & 'status="pending"' + + @property + def reserved(self) -> QueryExpression: + """Return query for reserved jobs.""" + return self & 'status="reserved"' + + @property + def errors(self) -> QueryExpression: + """Return query for error jobs.""" + return self & 'status="error"' + + @property + def ignored(self) -> QueryExpression: + """Return query for ignored jobs.""" + return self & 'status="ignore"' + + @property + def completed(self) -> QueryExpression: + """Return query for completed jobs.""" + return self & 'status="success"' +``` + +### AutoPopulate Integration + +The `populate()` method is updated to use the new jobs table: + +```python +def populate( + self, + *restrictions, + suppress_errors: bool = False, + return_exception_objects: bool = False, + reserve_jobs: bool = False, + order: str = "original", + limit: int = None, + max_calls: int = None, + display_progress: bool = False, + processes: int = 1, + make_kwargs: dict = None, + # New parameters + priority: int = None, # Only process jobs at this priority or more urgent (lower values) + refresh: bool = True, # Refresh jobs queue if no pending jobs available +) -> dict: + """ + Populate the table by calling make() for each missing entry. + + New behavior with reserve_jobs=True: + 1. Fetch all non-stale pending jobs (ordered by priority ASC, scheduled_time ASC) + 2. For each pending job: + a. Mark job as 'reserved' (per-key, before make) + b. Call make(key) + c. On success: mark job as 'success' or delete (based on keep_completed) + d. On error: mark job as 'error' with message/stack + 3. If refresh=True and no pending jobs were found, call self.jobs.refresh() + and repeat from step 1 + 4. Continue until no more pending jobs or max_calls reached + """ + ... +``` + +### Progress and Monitoring + +```python +# Current progress reporting +remaining, total = MyTable.progress() + +# Enhanced progress with jobs table +MyTable.jobs.progress() # Returns detailed status breakdown + +# Example output: +# { +# 'pending': 150, +# 'reserved': 3, +# 'success': 847, +# 'error': 12, +# 'ignore': 5, +# 'total': 1017 +# } +``` + +### Priority and Scheduling + +Priority and scheduling are handled via `refresh()` parameters. Lower priority values are more urgent (0 = highest priority). Scheduling uses relative time (seconds from now) based on database server time. + +```python +# Add urgent jobs (priority=0 is most urgent) +MyTable.jobs.refresh(priority=0) + +# Add normal jobs (default priority=5) +MyTable.jobs.refresh() + +# Add low-priority background jobs +MyTable.jobs.refresh(priority=10) + +# Schedule jobs for future processing (2 hours from now) +MyTable.jobs.refresh(delay=2*60*60) # 7200 seconds + +# Schedule jobs for tomorrow (24 hours from now) +MyTable.jobs.refresh(delay=24*60*60) + +# Combine: urgent jobs with 1-hour delay +MyTable.jobs.refresh(priority=0, delay=3600) + +# Add urgent jobs for specific subjects +MyTable.jobs.refresh(Subject & 'priority="urgent"', priority=0) +``` + +## Implementation Details + +### Table Naming Convention + +Jobs tables follow the existing hidden table naming pattern: +- Table `FilteredImage` (stored as `__filtered_image`) +- Jobs table: `~filtered_image__jobs` (stored as `_filtered_image__jobs`) + +### Primary Key Derivation + +The jobs table primary key includes only those attributes derived from foreign keys in the target table's primary key: + +```python +# Example 1: FK-only primary key (simple case) +@schema +class FilteredImage(dj.Computed): + definition = """ + -> Image + --- + filtered_image : + """ +# Jobs table primary key: (image_id) — same as target + +# Example 2: Target with additional PK attribute +@schema +class Analysis(dj.Computed): + definition = """ + -> Recording + analysis_method : varchar(32) # Additional PK attribute + --- + result : float + """ +# Jobs table primary key: (recording_id) — excludes 'analysis_method' +# One job entry covers all analysis_method values for a given recording +``` + +The jobs table has **no foreign key constraints** for performance reasons. + +### Stale Job Handling + +Stale jobs are pending jobs whose upstream records have been deleted. Since there are no FK constraints on jobs tables, these jobs remain until cleaned up by `refresh()`: + +```python +# refresh() handles stale jobs automatically +result = FilteredImage.jobs.refresh() +# Returns: {'added': 10, 'removed': 3} # 3 stale jobs cleaned up + +# Stale detection logic: +# 1. Find pending jobs where created_time < (now - stale_timeout) +# 2. Check if their keys still exist in key_source +# 3. Remove pending jobs whose keys no longer exist +``` + +**Why not use foreign key cascading deletes?** +- FK constraints add overhead on every insert/update/delete operation +- Jobs tables are high-traffic (frequent reservations and status updates) +- Stale jobs are harmless until refresh—they simply won't match key_source +- The `refresh()` approach is more efficient for batch cleanup + +### Table Drop and Alter Behavior + +When an auto-populated table is **dropped**, its associated jobs table is automatically dropped: + +```python +# Dropping FilteredImage also drops ~filtered_image__jobs +FilteredImage.drop() +``` + +When an auto-populated table is **altered** (e.g., primary key changes), the jobs table is dropped and can be recreated via `refresh()`: + +```python +# Alter that changes primary key structure +# Jobs table is dropped since its structure no longer matches +FilteredImage.alter() + +# Recreate jobs table with new structure +FilteredImage.jobs.refresh() +``` + +### Lazy Table Creation + +Jobs tables are created automatically on first use: + +```python +# First call to populate with reserve_jobs=True creates the jobs table +FilteredImage.populate(reserve_jobs=True) +# Creates ~filtered_image__jobs if it doesn't exist, then populates + +# Alternatively, explicitly create/refresh the jobs table +FilteredImage.jobs.refresh() +``` + +The jobs table is created with a primary key derived from the target table's foreign key attributes. + +### Conflict Resolution + +Conflict resolution relies on the transaction surrounding each `make()` call. This applies regardless of whether `reserve_jobs=True` or `reserve_jobs=False`: + +- With `reserve_jobs=False`: Workers query `key_source` directly and may attempt the same key +- With `reserve_jobs=True`: Job reservation reduces conflicts but doesn't eliminate them entirely + +When two workers attempt to populate the same key: +1. Both call `make()` for the same key +2. First worker's `make()` transaction commits, inserting the result +3. Second worker's `make()` transaction fails with duplicate key error +4. Second worker catches the error, and the job returns to `pending` or `(none)` state + +**Important**: Only errors that occur *inside* `make()` are logged with `error` status. Duplicate key errors from collisions occur outside the `make()` logic and are handled silently—the job is either retried or reverts to `pending`/`(none)`. This distinction ensures the error log contains only genuine computation failures, not coordination artifacts. + +**Why this is acceptable**: +- The `make()` transaction guarantees data integrity +- Duplicate key error is a clean, expected signal (not a real error) +- With `reserve_jobs=True`, conflicts are rare (requires near-simultaneous reservation) +- Wasted computation is minimal compared to locking complexity + +### Job Reservation vs Pre-Partitioning + +The job reservation mechanism (`reserve_jobs=True`) allows workers to dynamically claim jobs from a shared queue. However, some orchestration systems may prefer to **pre-partition** jobs before distributing them to workers: + +```python +# Pre-partitioning example: orchestrator divides work explicitly +all_pending = FilteredImage.jobs.pending.fetch("KEY") + +# Split jobs among workers (e.g., by worker index) +n_workers = 4 +for worker_id in range(n_workers): + worker_jobs = all_pending[worker_id::n_workers] # Round-robin assignment + # Send worker_jobs to worker via orchestration system (Slurm, K8s, etc.) + +# Worker receives its assigned keys and processes them directly +for key in assigned_keys: + FilteredImage.populate(key, reserve_jobs=False) +``` + +**When to use each approach**: + +| Approach | Use Case | +|----------|----------| +| **Dynamic reservation** (`reserve_jobs=True`) | Simple setups, variable job durations, workers that start/stop dynamically | +| **Pre-partitioning** | Batch schedulers (Slurm, PBS), predictable job counts, avoiding reservation overhead | + +Both approaches benefit from the same transaction-based conflict resolution as a safety net. + +### Orphaned Job Handling + +Orphaned jobs are reserved jobs from crashed or terminated processes. The API does not provide an algorithmic method for detecting or clearing orphaned jobs because this is dependent on the orchestration system (e.g., Slurm job IDs, Kubernetes pod status, process heartbeats). + +Users must manually clear orphaned jobs using the `delete()` method: + +```python +# Delete all reserved jobs (use with caution - may kill active jobs!) +MyTable.jobs.reserved.delete() + +# Delete reserved jobs from a specific host that crashed +(MyTable.jobs.reserved & 'host="crashed-node"').delete() + +# Delete reserved jobs older than 1 hour (likely orphaned) +(MyTable.jobs.reserved & 'reserved_time < NOW() - INTERVAL 1 HOUR').delete() + +# Delete and re-add as pending +MyTable.jobs.reserved.delete() +MyTable.jobs.refresh() +``` + +**Note**: Deleting a reserved job does not terminate the running worker—it simply removes the reservation record. If the worker is still running, it will complete its `make()` call. If the job is then refreshed as pending and picked up by another worker, duplicated work may occur. Coordinate with your orchestration system to identify truly orphaned jobs before clearing them. + +## Configuration Options + +New configuration settings for job management: + +```python +# In datajoint config +dj.config['jobs.auto_refresh'] = True # Auto-refresh on populate (default: True) +dj.config['jobs.keep_completed'] = False # Keep success records (default: False) +dj.config['jobs.stale_timeout'] = 3600 # Seconds before pending job is considered stale (default: 3600) +dj.config['jobs.default_priority'] = 5 # Default priority for new jobs (lower = more urgent) +``` + +## Usage Examples + +### Basic Distributed Computing + +```python +# Worker 1 +FilteredImage.populate(reserve_jobs=True) + +# Worker 2 (can run simultaneously) +FilteredImage.populate(reserve_jobs=True) + +# Monitor progress +print(FilteredImage.jobs.progress()) +``` + +### Priority-Based Processing + +```python +# Add urgent jobs (priority=0 is most urgent) +urgent_subjects = Subject & 'priority="urgent"' +FilteredImage.jobs.refresh(urgent_subjects, priority=0) + +# Workers will process lowest-priority-value jobs first +FilteredImage.populate(reserve_jobs=True) +``` + +### Scheduled Processing + +```python +# Schedule jobs for overnight processing (8 hours from now) +FilteredImage.jobs.refresh('subject_id > 100', delay=8*60*60) + +# Only jobs whose scheduled_time <= now will be processed +FilteredImage.populate(reserve_jobs=True) +``` + +### Error Recovery + +```python +# View errors +errors = FilteredImage.jobs.errors.fetch(as_dict=True) +for err in errors: + print(f"Key: {err['subject_id']}, Error: {err['error_message']}") + +# Delete specific error jobs after fixing the issue +(FilteredImage.jobs & 'subject_id=42').delete() + +# Delete all error jobs +FilteredImage.jobs.errors.delete() + +# Re-add deleted jobs as pending (if keys still in key_source) +FilteredImage.jobs.refresh() +``` + +### Dashboard Queries + +```python +# Get pipeline-wide status using schema.jobs +def pipeline_status(schema): + return { + jt.target.table_name: jt.progress() + for jt in schema.jobs + } + +# Example output: +# { +# 'FilteredImage': {'pending': 150, 'reserved': 3, 'success': 847, 'error': 12}, +# 'Analysis': {'pending': 500, 'reserved': 0, 'success': 0, 'error': 0}, +# } + +# Refresh all jobs tables in the schema +for jobs_table in schema.jobs: + jobs_table.refresh() + +# Get all errors across the pipeline +all_errors = [] +for jt in schema.jobs: + errors = jt.errors.fetch(as_dict=True) + for err in errors: + err['_table'] = jt.target.table_name + all_errors.append(err) +``` + +## Backward Compatibility + +### Migration + +This is a major release. The legacy schema-level `~jobs` table is replaced by per-table jobs tables: + +- **Legacy `~jobs` table**: No longer used; can be dropped manually if present +- **New jobs tables**: Created automatically on first `populate(reserve_jobs=True)` call +- **No parallel support**: Teams should migrate cleanly to the new system + +### API Compatibility + +The `schema.jobs` property returns a list of all jobs table objects for auto-populated tables in the schema: + +```python +# Returns list of JobsTable objects +schema.jobs +# [FilteredImage.jobs, Analysis.jobs, ...] + +# Iterate over all jobs tables +for jobs_table in schema.jobs: + print(f"{jobs_table.target.table_name}: {jobs_table.progress()}") + +# Query all errors across the schema +all_errors = [job for jt in schema.jobs for job in jt.errors.fetch(as_dict=True)] + +# Refresh all jobs tables +for jobs_table in schema.jobs: + jobs_table.refresh() +``` + +This replaces the legacy single `~jobs` table with direct access to per-table jobs. + +## Hazard Analysis + +This section identifies potential hazards and their mitigations. + +### Race Conditions + +| Hazard | Description | Mitigation | +|--------|-------------|------------| +| **Simultaneous reservation** | Two workers reserve the same pending job at nearly the same time | Acceptable: duplicate `make()` calls are resolved by transaction—second worker gets duplicate key error | +| **Reserve during refresh** | Worker reserves a job while another process is running `refresh()` | No conflict: `refresh()` adds new jobs and removes stale ones; reservation updates existing rows | +| **Concurrent refresh calls** | Multiple processes call `refresh()` simultaneously | Acceptable: may result in duplicate insert attempts, but primary key constraint prevents duplicates | +| **Complete vs delete race** | One process completes a job while another deletes it | Acceptable: one operation succeeds, other becomes no-op (row not found) | + +### State Transitions + +| Hazard | Description | Mitigation | +|--------|-------------|------------| +| **Invalid state transition** | Code attempts illegal transition (e.g., pending → success) | Implementation enforces valid transitions; invalid attempts raise error | +| **Stuck in reserved** | Worker crashes while job is reserved (orphaned job) | Manual intervention required: `jobs.reserved.delete()` (see Orphaned Job Handling) | +| **Success re-pended unexpectedly** | `refresh()` re-pends a success job when user expected it to stay | Only occurs if `keep_completed=True` AND key exists in `key_source` but not in target; document clearly | +| **Ignore not respected** | Ignored jobs get processed anyway | Implementation must skip `status='ignore'` in `populate()` job fetching | + +### Data Integrity + +| Hazard | Description | Mitigation | +|--------|-------------|------------| +| **Stale job processed** | Job references deleted upstream data | `make()` will fail or produce invalid results; `refresh()` cleans stale jobs before processing | +| **Jobs table out of sync** | Jobs table doesn't match `key_source` | `refresh()` synchronizes; call periodically or rely on `populate(refresh=True)` | +| **Partial make failure** | `make()` partially succeeds then fails | DataJoint transaction rollback ensures atomicity; job marked as error | +| **Error message truncation** | Error details exceed `varchar(2047)` | Full stack stored in `error_stack` (mediumblob); `error_message` is summary only | + +### Performance + +| Hazard | Description | Mitigation | +|--------|-------------|------------| +| **Large jobs table** | Jobs table grows very large with `keep_completed=True` | Default is `keep_completed=False`; provide guidance on periodic cleanup | +| **Slow refresh on large key_source** | `refresh()` queries entire `key_source` | Can restrict refresh to subsets: `jobs.refresh(Subject & 'lab="smith"')` | +| **Many jobs tables per schema** | Schema with many computed tables has many jobs tables | Jobs tables are lightweight; only created on first use | + +### Operational + +| Hazard | Description | Mitigation | +|--------|-------------|------------| +| **Accidental job deletion** | User runs `jobs.delete()` without restriction | `delete()` inherits from `delete_quick()` (no confirmation); users must apply restrictions carefully | +| **Clearing active jobs** | User clears reserved jobs while workers are still running | May cause duplicated work if job is refreshed and picked up again; coordinate with orchestrator | +| **Priority confusion** | User expects higher number = higher priority | Document clearly: lower values are more urgent (0 = highest priority) | + +### Migration + +| Hazard | Description | Mitigation | +|--------|-------------|------------| +| **Legacy ~jobs table conflict** | Old `~jobs` table exists alongside new per-table jobs | Systems are independent; legacy table can be dropped manually | +| **Mixed version workers** | Some workers use old system, some use new | Major release; do not support mixed operation—require full migration | +| **Lost error history** | Migrating loses error records from legacy table | Document migration procedure; users can export legacy errors before migration | + +## Future Extensions + +- [ ] Web-based dashboard for job monitoring +- [ ] Webhook notifications for job completion/failure +- [ ] Job dependencies (job B waits for job A) +- [ ] Resource tagging (GPU required, high memory, etc.) +- [ ] Retry policies (max retries, exponential backoff) +- [ ] Job grouping/batching for efficiency +- [ ] Integration with external schedulers (Slurm, PBS, etc.) + +## Rationale + +### Why Not External Orchestration? + +The team considered integrating external tools like Airflow or Flyte but rejected this approach because: + +1. **Deployment complexity**: External orchestrators require significant infrastructure +2. **Maintenance burden**: Additional systems to maintain and monitor +3. **Accessibility**: Not all DataJoint users have access to orchestration platforms +4. **Tight integration**: DataJoint's transaction model requires close coordination + +The built-in jobs system provides 80% of the value with minimal additional complexity. + +### Why Per-Table Jobs? + +Per-table jobs tables provide: + +1. **Better isolation**: Jobs for one table don't affect others +2. **Simpler queries**: No need to filter by table_name +3. **Native keys**: Primary keys are readable, not hashed +4. **High performance**: No FK constraints means minimal overhead on job operations +5. **Scalability**: Each table's jobs can be indexed independently + +### Why Remove Key Hashing? + +The current system hashes primary keys to support arbitrary key types. The new system uses native keys because: + +1. **Readability**: Debugging is much easier with readable keys +2. **Query efficiency**: Native keys can use table indexes +3. **Foreign keys**: Hash-based keys cannot participate in foreign key relationships +4. **Simplicity**: No need for hash computation and comparison + +### Why FK-Derived Primary Keys Only? + +The jobs table primary key includes only attributes derived from foreign keys in the target table's primary key. This design: + +1. **Aligns with key_source**: The `key_source` query naturally produces keys matching the FK-derived attributes +2. **Simplifies job identity**: A job's identity is determined by its upstream dependencies +3. **Handles additional PK attributes**: When targets have additional PK attributes (e.g., `method`), one job covers all values for that attribute diff --git a/docs/src/design/integrity.md b/docs/src/design/integrity.md index cb7122755..393103522 100644 --- a/docs/src/design/integrity.md +++ b/docs/src/design/integrity.md @@ -142,7 +142,7 @@ definition = """ -> EEGRecording channel_idx : int --- -channel_data : longblob +channel_data : """ ``` ![doc_1-many](../images/doc_1-many.png){: style="align:center"} diff --git a/docs/src/design/tables/attributes.md b/docs/src/design/tables/attributes.md index f3877cec9..2e8105e7c 100644 --- a/docs/src/design/tables/attributes.md +++ b/docs/src/design/tables/attributes.md @@ -48,9 +48,10 @@ fractional digits. Because of its well-defined precision, `decimal` values can be used in equality comparison and be included in primary keys. -- `longblob`: arbitrary numeric array (e.g. matrix, image, structure), up to 4 +- `longblob`: raw binary data, up to 4 [GiB](http://en.wikipedia.org/wiki/Gibibyte) in size. - Numeric arrays are compatible between MATLAB and Python (NumPy). + Stores and returns raw bytes without serialization. + For serialized Python objects (arrays, dicts, etc.), use `` instead. The `longblob` and other `blob` datatypes can be configured to store data [externally](../../sysadmin/external-store.md) by using the `blob@store` syntax. @@ -71,6 +72,10 @@ info). These types abstract certain kinds of non-database data to facilitate use together with DataJoint. +- ``: DataJoint's native serialization format for Python objects. Supports +NumPy arrays, dicts, lists, datetime objects, and nested structures. Compatible with +MATLAB. See [custom types](customtype.md) for details. + - `object`: managed [file and folder storage](object.md) with support for direct writes (Zarr, HDF5) and fsspec integration. Recommended for new pipelines. @@ -80,6 +85,10 @@ sending/receiving an opaque data file to/from a DataJoint pipeline. - `filepath@store`: a [filepath](filepath.md) used to link non-DataJoint managed files into a DataJoint pipeline. +- ``: a [custom attribute type](customtype.md) that defines bidirectional +conversion between Python objects and database storage formats. Use this to store +complex data types like graphs, domain-specific objects, or custom data structures. + ## Numeric type aliases DataJoint provides convenient type aliases that map to standard MySQL numeric types. diff --git a/docs/src/design/tables/customtype.md b/docs/src/design/tables/customtype.md index aad194ff5..267e0420b 100644 --- a/docs/src/design/tables/customtype.md +++ b/docs/src/design/tables/customtype.md @@ -1,4 +1,4 @@ -# Custom Types +# Custom Attribute Types In modern scientific research, data pipelines often involve complex workflows that generate diverse data types. From high-dimensional imaging data to machine learning @@ -12,69 +12,603 @@ traditional relational databases. For example: + Computational biologists might store fitted machine learning models or parameter objects for downstream predictions. -To handle these diverse needs, DataJoint provides the `dj.AttributeAdapter` method. It +To handle these diverse needs, DataJoint provides the **AttributeType** system. It enables researchers to store and retrieve complex, non-standard data types—like Python objects or data structures—in a relational database while maintaining the reproducibility, modularity, and query capabilities required for scientific workflows. -## Uses in Scientific Research +## Overview -Imagine a neuroscience lab studying neural connectivity. Researchers might generate -graphs (e.g., networkx.Graph) to represent connections between brain regions, where: +Custom attribute types define bidirectional conversion between: -+ Nodes are brain regions. -+ Edges represent connections weighted by signal strength or another metric. +- **Python objects** (what your code works with) +- **Storage format** (what gets stored in the database) -Storing these graph objects in a database alongside other experimental data (e.g., -subject metadata, imaging parameters) ensures: +``` +┌─────────────────┐ encode() ┌─────────────────┐ +│ Python Object │ ───────────────► │ Storage Type │ +│ (e.g. Graph) │ │ (e.g. blob) │ +└─────────────────┘ decode() └─────────────────┘ + ◄─────────────── +``` + +## Defining Custom Types + +Create a custom type by subclassing `dj.AttributeType` and implementing the required +methods: + +```python +import datajoint as dj +import networkx as nx + +@dj.register_type +class GraphType(dj.AttributeType): + """Custom type for storing networkx graphs.""" + + # Required: unique identifier used in table definitions + type_name = "graph" + + # Required: underlying DataJoint storage type + dtype = "longblob" + + def encode(self, graph, *, key=None): + """Convert graph to storable format (called on INSERT).""" + return list(graph.edges) + + def decode(self, edges, *, key=None): + """Convert stored data back to graph (called on FETCH).""" + return nx.Graph(edges) +``` + +### Required Components + +| Component | Description | +|-----------|-------------| +| `type_name` | Unique identifier used in table definitions with `` syntax | +| `dtype` | Underlying DataJoint type for storage (e.g., `"longblob"`, `"varchar(255)"`, `"json"`) | +| `encode(value, *, key=None)` | Converts Python object to storable format | +| `decode(stored, *, key=None)` | Converts stored data back to Python object | + +### Using Custom Types in Tables + +Once registered, use the type in table definitions with angle brackets: + +```python +@schema +class Connectivity(dj.Manual): + definition = """ + conn_id : int + --- + conn_graph = null : # Uses the GraphType we defined + """ +``` + +Insert and fetch work seamlessly: + +```python +import networkx as nx + +# Insert - encode() is called automatically +g = nx.lollipop_graph(4, 2) +Connectivity.insert1({"conn_id": 1, "conn_graph": g}) + +# Fetch - decode() is called automatically +result = (Connectivity & "conn_id = 1").fetch1("conn_graph") +assert isinstance(result, nx.Graph) +``` + +## Type Registration + +### Decorator Registration + +The simplest way to register a type is with the `@dj.register_type` decorator: + +```python +@dj.register_type +class MyType(dj.AttributeType): + type_name = "my_type" + ... +``` + +### Direct Registration + +You can also register types explicitly: + +```python +class MyType(dj.AttributeType): + type_name = "my_type" + ... + +dj.register_type(MyType) +``` + +### Listing Registered Types + +```python +# List all registered type names +print(dj.list_types()) +``` + +## Validation + +Add data validation by overriding the `validate()` method. It's called automatically +before `encode()` during INSERT operations: + +```python +@dj.register_type +class PositiveArrayType(dj.AttributeType): + type_name = "positive_array" + dtype = "longblob" + + def validate(self, value): + """Ensure all values are positive.""" + import numpy as np + if not isinstance(value, np.ndarray): + raise TypeError(f"Expected numpy array, got {type(value).__name__}") + if np.any(value < 0): + raise ValueError("Array must contain only positive values") + + def encode(self, array, *, key=None): + return array -1. Centralized Data Management: All experimental data and analysis results are stored - together for easy access and querying. -2. Reproducibility: The exact graph objects used in analysis can be retrieved later for - validation or further exploration. -3. Scalability: Graph data can be integrated into workflows for larger datasets or - across experiments. + def decode(self, stored, *, key=None): + return stored +``` + +## Storage Types (dtype) + +The `dtype` property specifies how data is stored in the database: + +| dtype | Use Case | Stored Format | +|-------|----------|---------------| +| `"longblob"` | Complex Python objects, arrays | Serialized binary | +| `"blob"` | Smaller objects | Serialized binary | +| `"json"` | JSON-serializable data | JSON string | +| `"varchar(N)"` | String representations | Text | +| `"int"` | Integer identifiers | Integer | +| `"blob@store"` | Large objects in external storage | UUID reference | +| `"object"` | Files/folders in object storage | JSON metadata | +| `""` | Chain to another custom type | Varies | + +### External Storage + +For large data, use external blob storage: + +```python +@dj.register_type +class LargeArrayType(dj.AttributeType): + type_name = "large_array" + dtype = "blob@mystore" # Uses external store named "mystore" + + def encode(self, array, *, key=None): + return array + + def decode(self, stored, *, key=None): + return stored +``` + +## Type Chaining + +Custom types can build on other custom types by referencing them in `dtype`: + +```python +@dj.register_type +class CompressedGraphType(dj.AttributeType): + type_name = "compressed_graph" + dtype = "" # Chain to the GraphType + + def encode(self, graph, *, key=None): + # Compress before passing to GraphType + return self._compress(graph) + + def decode(self, stored, *, key=None): + # GraphType's decode already ran + return self._decompress(stored) +``` + +DataJoint automatically resolves the chain to find the final storage type. + +## The Key Parameter + +The `key` parameter provides access to primary key values during encode/decode +operations. This is useful when the conversion depends on record context: + +```python +@dj.register_type +class ContextAwareType(dj.AttributeType): + type_name = "context_aware" + dtype = "longblob" + + def encode(self, value, *, key=None): + if key and key.get("version") == 2: + return self._encode_v2(value) + return self._encode_v1(value) + + def decode(self, stored, *, key=None): + if key and key.get("version") == 2: + return self._decode_v2(stored) + return self._decode_v1(stored) +``` + +## Publishing Custom Types as Packages -However, since graphs are not natively supported by relational databases, here’s where -`dj.AttributeAdapter` becomes essential. It allows researchers to define custom logic for -serializing graphs (e.g., as edge lists) and deserializing them back into Python -objects, bridging the gap between advanced data types and the database. +Custom types can be distributed as installable packages using Python entry points. +This allows types to be automatically discovered when the package is installed. -### Example: Storing Graphs in DataJoint +### Package Structure -To store a networkx.Graph object in a DataJoint table, researchers can define a custom -attribute type in a datajoint table class: +``` +dj-graph-types/ +├── pyproject.toml +└── src/ + └── dj_graph_types/ + ├── __init__.py + └── types.py +``` + +### pyproject.toml + +```toml +[project] +name = "dj-graph-types" +version = "1.0.0" + +[project.entry-points."datajoint.types"] +graph = "dj_graph_types.types:GraphType" +weighted_graph = "dj_graph_types.types:WeightedGraphType" +``` + +### Type Implementation ```python +# src/dj_graph_types/types.py import datajoint as dj +import networkx as nx -class GraphAdapter(dj.AttributeAdapter): +class GraphType(dj.AttributeType): + type_name = "graph" + dtype = "longblob" + + def encode(self, graph, *, key=None): + return list(graph.edges) + + def decode(self, edges, *, key=None): + return nx.Graph(edges) + +class WeightedGraphType(dj.AttributeType): + type_name = "weighted_graph" + dtype = "longblob" + + def encode(self, graph, *, key=None): + return [(u, v, d) for u, v, d in graph.edges(data=True)] + + def decode(self, edges, *, key=None): + g = nx.Graph() + g.add_weighted_edges_from(edges) + return g +``` + +### Usage After Installation + +```bash +pip install dj-graph-types +``` + +```python +# Types are automatically available after package installation +@schema +class MyTable(dj.Manual): + definition = """ + id : int + --- + network : + weighted_network : + """ +``` + +## Complete Example + +Here's a complete example demonstrating custom types for a neuroscience workflow: + +```python +import datajoint as dj +import numpy as np + +# Configure DataJoint +dj.config["database.host"] = "localhost" +dj.config["database.user"] = "root" +dj.config["database.password"] = "password" + +# Define custom types +@dj.register_type +class SpikeTrainType(dj.AttributeType): + """Efficient storage for sparse spike timing data.""" + type_name = "spike_train" + dtype = "longblob" + + def validate(self, value): + if not isinstance(value, np.ndarray): + raise TypeError("Expected numpy array of spike times") + if value.ndim != 1: + raise ValueError("Spike train must be 1-dimensional") + if not np.all(np.diff(value) >= 0): + raise ValueError("Spike times must be sorted") + + def encode(self, spike_times, *, key=None): + # Store as differences (smaller values, better compression) + return np.diff(spike_times, prepend=0).astype(np.float32) + + def decode(self, stored, *, key=None): + # Reconstruct original spike times + return np.cumsum(stored).astype(np.float64) - attribute_type = 'longblob' # this is how the attribute will be declared + +@dj.register_type +class WaveformType(dj.AttributeType): + """Storage for spike waveform templates with metadata.""" + type_name = "waveform" + dtype = "longblob" + + def encode(self, waveform_dict, *, key=None): + return { + "data": waveform_dict["data"].astype(np.float32), + "sampling_rate": waveform_dict["sampling_rate"], + "channel_ids": list(waveform_dict["channel_ids"]), + } + + def decode(self, stored, *, key=None): + return { + "data": stored["data"].astype(np.float64), + "sampling_rate": stored["sampling_rate"], + "channel_ids": np.array(stored["channel_ids"]), + } + + +# Create schema and tables +schema = dj.schema("ephys_analysis") + +@schema +class Unit(dj.Manual): + definition = """ + unit_id : int + --- + spike_times : + waveform : + quality : enum('good', 'mua', 'noise') + """ + + +# Usage +spike_times = np.array([0.1, 0.15, 0.23, 0.45, 0.67, 0.89]) +waveform = { + "data": np.random.randn(82, 4), + "sampling_rate": 30000, + "channel_ids": [10, 11, 12, 13], +} + +Unit.insert1({ + "unit_id": 1, + "spike_times": spike_times, + "waveform": waveform, + "quality": "good", +}) + +# Fetch - automatically decoded +result = (Unit & "unit_id = 1").fetch1() +print(f"Spike times: {result['spike_times']}") +print(f"Waveform shape: {result['waveform']['data'].shape}") +``` + +## Migration from AttributeAdapter + +The `AttributeAdapter` class is deprecated. Migrate to `AttributeType`: + +### Before (deprecated) + +```python +class GraphAdapter(dj.AttributeAdapter): + attribute_type = "longblob" def put(self, obj): - # convert the nx.Graph object into an edge list - assert isinstance(obj, nx.Graph) return list(obj.edges) def get(self, value): - # convert edge list back into an nx.Graph return nx.Graph(value) - -# instantiate for use as a datajoint type +# Required context-based registration graph = GraphAdapter() +schema = dj.schema("mydb", context={"graph": graph}) +``` + +### After (recommended) + +```python +@dj.register_type +class GraphType(dj.AttributeType): + type_name = "graph" + dtype = "longblob" + + def encode(self, obj, *, key=None): + return list(obj.edges) + + def decode(self, value, *, key=None): + return nx.Graph(value) + +# Global registration - no context needed +schema = dj.schema("mydb") +``` + +### Key Differences + +| Aspect | AttributeAdapter (deprecated) | AttributeType (recommended) | +|--------|-------------------------------|----------------------------| +| Methods | `put()` / `get()` | `encode()` / `decode()` | +| Storage type | `attribute_type` | `dtype` | +| Type name | Variable name in context | `type_name` property | +| Registration | Context dict per schema | Global `@register_type` decorator | +| Validation | Manual | Built-in `validate()` method | +| Distribution | Copy adapter code | Entry point packages | +| Key access | Not available | Optional `key` parameter | + +## Best Practices + +1. **Choose descriptive type names**: Use lowercase with underscores (e.g., `spike_train`, `graph_embedding`) + +2. **Select appropriate storage types**: Use `` for complex objects, `json` for simple structures, external storage for large data + +3. **Add validation**: Use `validate()` to catch data errors early + +4. **Document your types**: Include docstrings explaining the expected input/output formats + +5. **Handle None values**: Your encode/decode methods may receive `None` for nullable attributes + +6. **Consider versioning**: If your encoding format might change, include version information + +7. **Test round-trips**: Ensure `decode(encode(x)) == x` for all valid inputs + +```python +def test_graph_type_roundtrip(): + g = nx.lollipop_graph(4, 2) + t = GraphType() + + encoded = t.encode(g) + decoded = t.decode(encoded) + + assert set(g.edges) == set(decoded.edges) +``` +## Built-in Types -# define a table with a graph attribute -schema = dj.schema('test_graphs') +DataJoint includes a built-in type for explicit blob serialization: +### `` - DataJoint Blob Serialization +The `` type provides explicit control over DataJoint's native binary +serialization. It supports: + +- NumPy arrays (compatible with MATLAB) +- Python dicts, lists, tuples, sets +- datetime objects, Decimals, UUIDs +- Nested data structures +- Optional compression + +```python @schema -class Connectivity(dj.Manual): +class ProcessedData(dj.Manual): definition = """ - conn_id : int + data_id : int + --- + results : # Serialized Python objects + raw_bytes : longblob # Raw bytes (no serialization) + """ +``` + +#### When to Use `` + +- **Serialized data**: When storing Python objects (dicts, arrays, etc.) +- **New tables**: Prefer `` for automatic serialization +- **Migration**: Existing schemas with implicit serialization must migrate + +#### Raw Blob Behavior + +Plain `longblob` (and other blob variants) columns now store and return +**raw bytes** without automatic serialization: + +```python +@schema +class RawData(dj.Manual): + definition = """ + id : int + --- + raw_bytes : longblob # Stores/returns raw bytes + serialized : # Stores Python objects with serialization + """ + +# Raw bytes - no serialization +RawData.insert1({"id": 1, "raw_bytes": b"raw binary data", "serialized": {"key": "value"}}) + +row = (RawData & "id=1").fetch1() +row["raw_bytes"] # Returns: b"raw binary data" +row["serialized"] # Returns: {"key": "value"} +``` + +**Important**: Existing schemas that relied on implicit blob serialization +must be migrated to `` to preserve their behavior. + +## Schema Migration + +When upgrading existing schemas to use explicit type declarations, DataJoint +provides migration utilities. + +### Analyzing Blob Columns + +```python +import datajoint as dj + +schema = dj.schema("my_database") + +# Check migration status +status = dj.migrate.check_migration_status(schema) +print(f"Blob columns: {status['total_blob_columns']}") +print(f"Already migrated: {status['migrated']}") +print(f"Pending migration: {status['pending']}") +``` + +### Generating Migration SQL + +```python +# Preview migration (dry run) +result = dj.migrate.migrate_blob_columns(schema, dry_run=True) +for sql in result['sql_statements']: + print(sql) +``` + +### Applying Migration + +```python +# Apply migration +result = dj.migrate.migrate_blob_columns(schema, dry_run=False) +print(f"Migrated {result['migrated']} columns") +``` + +### Migration Details + +The migration updates MySQL column comments to include the type declaration. +This is a **metadata-only** change - the actual blob data format is unchanged. + +All blob type variants are handled: `tinyblob`, `blob`, `mediumblob`, `longblob`. + +Before migration: +- Column: `longblob` (or `blob`, `mediumblob`, etc.) +- Comment: `user comment` +- Behavior: Auto-serialization (implicit) + +After migration: +- Column: `longblob` (unchanged) +- Comment: `::user comment` +- Behavior: Explicit serialization via `` + +### Updating Table Definitions + +After database migration, update your Python table definitions for consistency: + +```python +# Before +class MyTable(dj.Manual): + definition = """ + id : int + --- + data : longblob # stored data + """ + +# After +class MyTable(dj.Manual): + definition = """ + id : int --- - conn_graph = null : # a networkx.Graph object + data : # stored data """ ``` + +Both definitions work identically after migration, but using `` makes +the serialization explicit and documents the intended behavior. diff --git a/docs/src/design/tables/master-part.md b/docs/src/design/tables/master-part.md index 629bfb8ab..d0f575e4d 100644 --- a/docs/src/design/tables/master-part.md +++ b/docs/src/design/tables/master-part.md @@ -26,8 +26,8 @@ class Segmentation(dj.Computed): -> Segmentation roi : smallint # roi number --- - roi_pixels : longblob # indices of pixels - roi_weights : longblob # weights of pixels + roi_pixels : # indices of pixels + roi_weights : # weights of pixels """ def make(self, key): @@ -101,7 +101,7 @@ definition = """ -> ElectrodeResponse channel: int --- -response: longblob # response of a channel +response: # response of a channel """ ``` diff --git a/src/datajoint/__init__.py b/src/datajoint/__init__.py index 2fba6bd84..405134630 100644 --- a/src/datajoint/__init__.py +++ b/src/datajoint/__init__.py @@ -45,8 +45,12 @@ "kill", "MatCell", "MatStruct", - "AttributeAdapter", + "AttributeType", + "register_type", + "list_types", + "AttributeAdapter", # Deprecated, use AttributeType "errors", + "migrate", "DataJointError", "key", "key_hash", @@ -56,8 +60,10 @@ ] from . import errors +from . import migrate from .admin import kill from .attribute_adapter import AttributeAdapter +from .attribute_type import AttributeType, list_types, register_type from .blob import MatCell, MatStruct from .cli import cli from .connection import Connection, conn diff --git a/src/datajoint/attribute_adapter.py b/src/datajoint/attribute_adapter.py index 12a34f27e..7df566a58 100644 --- a/src/datajoint/attribute_adapter.py +++ b/src/datajoint/attribute_adapter.py @@ -1,61 +1,211 @@ +""" +Legacy attribute adapter module. + +This module provides backward compatibility for the deprecated AttributeAdapter class. +New code should use :class:`datajoint.AttributeType` instead. + +.. deprecated:: 0.15 + Use :class:`datajoint.AttributeType` with ``encode``/``decode`` methods. +""" + import re +import warnings +from typing import Any + +from .attribute_type import AttributeType, get_type, is_type_registered +from .errors import DataJointError -from .errors import DataJointError, _support_adapted_types +# Pattern to detect blob types for internal pack/unpack +_BLOB_PATTERN = re.compile(r"^(tiny|small|medium|long|)blob", re.I) -class AttributeAdapter: +class AttributeAdapter(AttributeType): """ - Base class for adapter objects for user-defined attribute types. + Legacy base class for attribute adapters. + + .. deprecated:: 0.15 + Use :class:`datajoint.AttributeType` with ``encode``/``decode`` methods instead. + + This class provides backward compatibility for existing adapters that use + the ``attribute_type``, ``put()``, and ``get()`` API. + + Migration guide:: + + # Old style (deprecated): + class GraphAdapter(dj.AttributeAdapter): + attribute_type = "longblob" + + def put(self, graph): + return list(graph.edges) + + def get(self, edges): + return nx.Graph(edges) + + # New style (recommended): + @dj.register_type + class GraphType(dj.AttributeType): + type_name = "graph" + dtype = "longblob" + + def encode(self, graph, *, key=None): + return list(graph.edges) + + def decode(self, edges, *, key=None): + return nx.Graph(edges) """ + # Subclasses can set this as a class attribute instead of property + attribute_type: str = None # type: ignore + + def __init__(self): + # Emit deprecation warning on instantiation + warnings.warn( + f"{self.__class__.__name__} uses the deprecated AttributeAdapter API. " + "Migrate to AttributeType with encode/decode methods.", + DeprecationWarning, + stacklevel=2, + ) + @property - def attribute_type(self): + def type_name(self) -> str: """ - :return: a supported DataJoint attribute type to use; e.g. "longblob", "blob@store" + Infer type name from class name for legacy adapters. + + Legacy adapters were identified by their variable name in the context dict, + not by a property. For backward compatibility, we use the lowercase class name. + """ + # Check if a _type_name was explicitly set (for context-based lookup) + if hasattr(self, "_type_name"): + return self._type_name + # Fall back to class name + return self.__class__.__name__.lower() + + @property + def dtype(self) -> str: + """Map legacy attribute_type to new dtype property.""" + attr_type = self.attribute_type + if attr_type is None: + raise NotImplementedError( + f"{self.__class__.__name__} must define 'attribute_type' " "(or migrate to AttributeType with 'dtype')" + ) + return attr_type + + def _is_blob_dtype(self) -> bool: + """Check if dtype is a blob type requiring pack/unpack.""" + return bool(_BLOB_PATTERN.match(self.dtype)) + + def encode(self, value: Any, *, key: dict | None = None) -> Any: """ - raise NotImplementedError("Undefined attribute adapter") + Delegate to legacy put() method, with blob packing if needed. - def get(self, value): + Legacy adapters expect blob.pack to be called after put() when + the dtype is a blob type. This wrapper handles that automatically. """ - convert value retrieved from the the attribute in a table into the adapted type + result = self.put(value) + # Legacy adapters expect blob.pack after put() for blob dtypes + if self._is_blob_dtype(): + from . import blob - :param value: value from the database + result = blob.pack(result) + return result - :return: object of the adapted type + def decode(self, stored: Any, *, key: dict | None = None) -> Any: """ - raise NotImplementedError("Undefined attribute adapter") + Delegate to legacy get() method, with blob unpacking if needed. - def put(self, obj): + Legacy adapters expect blob.unpack to be called before get() when + the dtype is a blob type. This wrapper handles that automatically. """ - convert an object of the adapted type into a value that DataJoint can store in a table attribute + # Legacy adapters expect blob.unpack before get() for blob dtypes + if self._is_blob_dtype(): + from . import blob + + stored = blob.unpack(stored) + return self.get(stored) - :param obj: an object of the adapted type - :return: value to store in the database + def put(self, obj: Any) -> Any: """ - raise NotImplementedError("Undefined attribute adapter") + Convert an object of the adapted type into a storable value. + + .. deprecated:: 0.15 + Override ``encode()`` instead. + Args: + obj: An object of the adapted type. -def get_adapter(context, adapter_name): + Returns: + Value to store in the database. + """ + raise NotImplementedError(f"{self.__class__.__name__} must implement put() or migrate to encode()") + + def get(self, value: Any) -> Any: + """ + Convert a value from the database into the adapted type. + + .. deprecated:: 0.15 + Override ``decode()`` instead. + + Args: + value: Value from the database. + + Returns: + Object of the adapted type. + """ + raise NotImplementedError(f"{self.__class__.__name__} must implement get() or migrate to decode()") + + +def get_adapter(context: dict | None, adapter_name: str) -> AttributeType: """ - Extract the AttributeAdapter object by its name from the context and validate. + Get an attribute type/adapter by name. + + This function provides backward compatibility by checking both: + 1. The global type registry (new system) + 2. The schema context dict (legacy system) + + Args: + context: Schema context dictionary (for legacy adapters). + adapter_name: The adapter/type name, with or without angle brackets. + + Returns: + The AttributeType instance. + + Raises: + DataJointError: If the adapter is not found or invalid. """ - if not _support_adapted_types(): - raise DataJointError("Support for Adapted Attribute types is disabled.") adapter_name = adapter_name.lstrip("<").rstrip(">") + + # First, check the global type registry (new system) + if is_type_registered(adapter_name): + return get_type(adapter_name) + + # Fall back to context-based lookup (legacy system) + if context is None: + raise DataJointError( + f"Attribute type <{adapter_name}> is not registered. " "Use @dj.register_type to register custom types." + ) + try: adapter = context[adapter_name] except KeyError: - raise DataJointError("Attribute adapter '{adapter_name}' is not defined.".format(adapter_name=adapter_name)) - if not isinstance(adapter, AttributeAdapter): raise DataJointError( - "Attribute adapter '{adapter_name}' must be an instance of datajoint.AttributeAdapter".format( - adapter_name=adapter_name - ) + f"Attribute type <{adapter_name}> is not defined. " + "Register it with @dj.register_type or include it in the schema context." ) - if not isinstance(adapter.attribute_type, str) or not re.match(r"^\w", adapter.attribute_type): + + # Validate it's an AttributeType (or legacy AttributeAdapter) + if not isinstance(adapter, AttributeType): raise DataJointError( - "Invalid attribute type {type} in attribute adapter '{adapter_name}'".format( - type=adapter.attribute_type, adapter_name=adapter_name - ) + f"Attribute adapter '{adapter_name}' must be an instance of " + "datajoint.AttributeType (or legacy datajoint.AttributeAdapter)" ) + + # For legacy adapters from context, store the name they were looked up by + if isinstance(adapter, AttributeAdapter): + adapter._type_name = adapter_name + + # Validate the dtype/attribute_type + dtype = adapter.dtype + if not isinstance(dtype, str) or not re.match(r"^\w", dtype): + raise DataJointError(f"Invalid dtype '{dtype}' in attribute type <{adapter_name}>") + return adapter diff --git a/src/datajoint/attribute_type.py b/src/datajoint/attribute_type.py new file mode 100644 index 000000000..9be2d2214 --- /dev/null +++ b/src/datajoint/attribute_type.py @@ -0,0 +1,531 @@ +""" +Custom attribute type system for DataJoint. + +This module provides the AttributeType base class and registration mechanism +for creating custom data types that extend DataJoint's native type system. + +Custom types enable seamless integration of complex Python objects (like NumPy arrays, +graphs, or domain-specific structures) with DataJoint's relational storage. + +Example: + @dj.register_type + class GraphType(dj.AttributeType): + type_name = "graph" + dtype = "longblob" + + def encode(self, graph: nx.Graph) -> list: + return list(graph.edges) + + def decode(self, edges: list) -> nx.Graph: + return nx.Graph(edges) + + # Then use in table definitions: + class MyTable(dj.Manual): + definition = ''' + id : int + --- + data : + ''' +""" + +from __future__ import annotations + +import logging +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any + +from .errors import DataJointError + +if TYPE_CHECKING: + pass + +logger = logging.getLogger(__name__.split(".")[0]) + +# Global type registry - maps type_name to AttributeType instance +_type_registry: dict[str, AttributeType] = {} +_entry_points_loaded: bool = False + + +class AttributeType(ABC): + """ + Base class for custom DataJoint attribute types. + + Subclass this to create custom types that can be used in table definitions + with the ```` syntax. Custom types define bidirectional conversion + between Python objects and DataJoint's storage format. + + Attributes: + type_name: Unique identifier used in ```` syntax + dtype: Underlying DataJoint storage type + + Example: + @dj.register_type + class GraphType(dj.AttributeType): + type_name = "graph" + dtype = "longblob" + + def encode(self, graph): + return list(graph.edges) + + def decode(self, edges): + import networkx as nx + return nx.Graph(edges) + + The type can then be used in table definitions:: + + class Connectivity(dj.Manual): + definition = ''' + id : int + --- + graph_data : + ''' + """ + + @property + @abstractmethod + def type_name(self) -> str: + """ + Unique identifier for this type, used in table definitions as ````. + + This name must be unique across all registered types. It should be lowercase + with underscores (e.g., "graph", "zarr_array", "compressed_image"). + + Returns: + The type name string without angle brackets. + """ + ... + + @property + @abstractmethod + def dtype(self) -> str: + """ + The underlying DataJoint type used for storage. + + Can be: + - A native type: ``"longblob"``, ``"blob"``, ``"varchar(255)"``, ``"int"``, ``"json"`` + - An external type: ``"blob@store"``, ``"attach@store"`` + - The object type: ``"object"`` + - Another custom type: ``""`` (enables type chaining) + + Returns: + The storage type specification string. + """ + ... + + @abstractmethod + def encode(self, value: Any, *, key: dict | None = None) -> Any: + """ + Convert a Python object to the storable format. + + Called during INSERT operations to transform user-provided objects + into a format suitable for storage in the underlying ``dtype``. + + Args: + value: The Python object to store. + key: Primary key values as a dict. Available when the dtype uses + object storage and may be needed for path construction. + + Returns: + Value in the format expected by ``dtype``. For example: + - For ``dtype="longblob"``: any picklable Python object + - For ``dtype="object"``: path string or file-like object + - For ``dtype="varchar(N)"``: string + """ + ... + + @abstractmethod + def decode(self, stored: Any, *, key: dict | None = None) -> Any: + """ + Convert stored data back to a Python object. + + Called during FETCH operations to reconstruct the original Python + object from the stored format. + + Args: + stored: Data retrieved from storage. Type depends on ``dtype``: + - For ``"object"``: an ``ObjectRef`` handle + - For blob types: the unpacked Python object + - For native types: the native Python value (str, int, etc.) + key: Primary key values as a dict. + + Returns: + The reconstructed Python object. + """ + ... + + def validate(self, value: Any) -> None: + """ + Validate a value before encoding. + + Override this method to add type checking or domain constraints. + Called automatically before ``encode()`` during INSERT operations. + The default implementation accepts any value. + + Args: + value: The value to validate. + + Raises: + TypeError: If the value has an incompatible type. + ValueError: If the value fails domain validation. + """ + pass + + def default(self) -> Any: + """ + Return a default value for this type. + + Override if the type has a sensible default value. The default + implementation raises NotImplementedError, indicating no default exists. + + Returns: + The default value for this type. + + Raises: + NotImplementedError: If no default exists (the default behavior). + """ + raise NotImplementedError(f"No default value for type <{self.type_name}>") + + def __repr__(self) -> str: + return f"<{self.__class__.__name__}(type_name={self.type_name!r}, dtype={self.dtype!r})>" + + +def register_type(cls: type[AttributeType]) -> type[AttributeType]: + """ + Register a custom attribute type with DataJoint. + + Can be used as a decorator or called directly. The type becomes available + for use in table definitions with the ```` syntax. + + Args: + cls: An AttributeType subclass to register. + + Returns: + The same class, unmodified (allows use as decorator). + + Raises: + DataJointError: If a type with the same name is already registered + by a different class. + TypeError: If cls is not an AttributeType subclass. + + Example: + As a decorator:: + + @dj.register_type + class GraphType(dj.AttributeType): + type_name = "graph" + ... + + Or called directly:: + + dj.register_type(GraphType) + """ + if not isinstance(cls, type) or not issubclass(cls, AttributeType): + raise TypeError(f"register_type requires an AttributeType subclass, got {cls!r}") + + instance = cls() + name = instance.type_name + + if not isinstance(name, str) or not name: + raise DataJointError(f"type_name must be a non-empty string, got {name!r}") + + if name in _type_registry: + existing = _type_registry[name] + if type(existing) is not cls: + raise DataJointError( + f"Type <{name}> is already registered by " f"{type(existing).__module__}.{type(existing).__name__}" + ) + # Same class registered twice - idempotent, no error + return cls + + _type_registry[name] = instance + logger.debug(f"Registered attribute type <{name}> from {cls.__module__}.{cls.__name__}") + return cls + + +def unregister_type(name: str) -> None: + """ + Remove a type from the registry. + + Primarily useful for testing. Use with caution in production code. + + Args: + name: The type_name to unregister. + + Raises: + DataJointError: If the type is not registered. + """ + name = name.strip("<>") + if name not in _type_registry: + raise DataJointError(f"Type <{name}> is not registered") + del _type_registry[name] + + +def get_type(name: str) -> AttributeType: + """ + Retrieve a registered attribute type by name. + + Looks up the type in the explicit registry first, then attempts + to load from installed packages via entry points. + + Args: + name: The type name, with or without angle brackets. + + Returns: + The registered AttributeType instance. + + Raises: + DataJointError: If the type is not found. + """ + name = name.strip("<>") + + # Check explicit registry first + if name in _type_registry: + return _type_registry[name] + + # Lazy-load entry points + _load_entry_points() + + if name in _type_registry: + return _type_registry[name] + + raise DataJointError( + f"Unknown attribute type: <{name}>. " f"Ensure the type is registered via @dj.register_type or installed as a package." + ) + + +def list_types() -> list[str]: + """ + List all registered type names. + + Returns: + Sorted list of registered type names. + """ + _load_entry_points() + return sorted(_type_registry.keys()) + + +def is_type_registered(name: str) -> bool: + """ + Check if a type name is registered. + + Args: + name: The type name to check. + + Returns: + True if the type is registered. + """ + name = name.strip("<>") + if name in _type_registry: + return True + _load_entry_points() + return name in _type_registry + + +def _load_entry_points() -> None: + """ + Load attribute types from installed packages via entry points. + + Types are discovered from the ``datajoint.types`` entry point group. + Packages declare types in pyproject.toml:: + + [project.entry-points."datajoint.types"] + zarr_array = "dj_zarr:ZarrArrayType" + + This function is idempotent - entry points are only loaded once. + """ + global _entry_points_loaded + if _entry_points_loaded: + return + + _entry_points_loaded = True + + try: + from importlib.metadata import entry_points + except ImportError: + # Python < 3.10 fallback + try: + from importlib_metadata import entry_points + except ImportError: + logger.debug("importlib.metadata not available, skipping entry point discovery") + return + + try: + # Python 3.10+ / importlib_metadata 3.6+ + eps = entry_points(group="datajoint.types") + except TypeError: + # Older API + eps = entry_points().get("datajoint.types", []) + + for ep in eps: + if ep.name in _type_registry: + # Already registered explicitly, skip entry point + continue + try: + type_class = ep.load() + register_type(type_class) + logger.debug(f"Loaded attribute type <{ep.name}> from entry point {ep.value}") + except Exception as e: + logger.warning(f"Failed to load attribute type '{ep.name}' from {ep.value}: {e}") + + +def resolve_dtype(dtype: str, seen: set[str] | None = None) -> tuple[str, list[AttributeType]]: + """ + Resolve a dtype string, following type chains. + + If dtype references another custom type (e.g., ""), recursively + resolves to find the ultimate storage type. + + Args: + dtype: The dtype string to resolve. + seen: Set of already-seen type names (for cycle detection). + + Returns: + Tuple of (final_storage_type, list_of_types_in_chain). + The chain is ordered from outermost to innermost type. + + Raises: + DataJointError: If a circular type reference is detected. + """ + if seen is None: + seen = set() + + chain: list[AttributeType] = [] + + # Check if dtype is a custom type reference + if dtype.startswith("<") and dtype.endswith(">"): + type_name = dtype[1:-1] + + if type_name in seen: + raise DataJointError(f"Circular type reference detected: <{type_name}>") + + seen.add(type_name) + attr_type = get_type(type_name) + chain.append(attr_type) + + # Recursively resolve the inner dtype + inner_dtype, inner_chain = resolve_dtype(attr_type.dtype, seen) + chain.extend(inner_chain) + return inner_dtype, chain + + # Not a custom type - return as-is + return dtype, chain + + +# ============================================================================= +# Built-in Attribute Types +# ============================================================================= + + +class DJBlobType(AttributeType): + """ + Built-in type for DataJoint's native serialization format. + + This type handles serialization of arbitrary Python objects (including NumPy arrays, + dictionaries, lists, etc.) using DataJoint's binary blob format. The format includes: + + - Protocol headers (``mYm`` for MATLAB-compatible, ``dj0`` for Python-native) + - Optional compression (zlib) + - Support for NumPy arrays, datetime objects, UUIDs, and nested structures + + The ```` type is the explicit way to specify DataJoint's serialization. + It stores data in a MySQL ``LONGBLOB`` column. + + Example: + @schema + class ProcessedData(dj.Manual): + definition = ''' + data_id : int + --- + results : # Serialized Python objects + raw_bytes : longblob # Raw bytes (no serialization) + ''' + + Note: + Plain ``longblob`` columns store and return raw bytes without serialization. + Use ```` when you need automatic serialization of Python objects. + Existing schemas using implicit blob serialization should migrate to ```` + using ``dj.migrate.migrate_blob_columns()``. + """ + + type_name = "djblob" + dtype = "longblob" + + def encode(self, value: Any, *, key: dict | None = None) -> bytes: + """ + Serialize a Python object to DataJoint's blob format. + + Args: + value: Any serializable Python object (dict, list, numpy array, etc.) + key: Primary key values (unused for blob serialization). + + Returns: + Serialized bytes with protocol header and optional compression. + """ + from . import blob + + return blob.pack(value, compress=True) + + def decode(self, stored: bytes, *, key: dict | None = None) -> Any: + """ + Deserialize DataJoint blob format back to a Python object. + + Args: + stored: Serialized blob bytes. + key: Primary key values (unused for blob serialization). + + Returns: + The deserialized Python object. + """ + from . import blob + + return blob.unpack(stored, squeeze=False) + + +class DJBlobExternalType(AttributeType): + """ + Built-in type for externally-stored DataJoint blobs. + + Similar to ```` but stores data in external blob storage instead + of inline in the database. Useful for large objects. + + The store name is specified when defining the column type. + + Example: + @schema + class LargeData(dj.Manual): + definition = ''' + data_id : int + --- + large_array : blob@mystore # External storage with auto-serialization + ''' + """ + + # Note: This type isn't directly usable via syntax + # It's used internally when blob@store syntax is detected + type_name = "djblob_external" + dtype = "blob@store" # Placeholder - actual store is determined at declaration time + + def encode(self, value: Any, *, key: dict | None = None) -> bytes: + """Serialize a Python object to DataJoint's blob format.""" + from . import blob + + return blob.pack(value, compress=True) + + def decode(self, stored: bytes, *, key: dict | None = None) -> Any: + """Deserialize DataJoint blob format back to a Python object.""" + from . import blob + + return blob.unpack(stored, squeeze=False) + + +def _register_builtin_types() -> None: + """ + Register DataJoint's built-in attribute types. + + Called automatically during module initialization. + """ + register_type(DJBlobType) + + +# Register built-in types when module is loaded +_register_builtin_types() diff --git a/src/datajoint/autopopulate.py b/src/datajoint/autopopulate.py index 677a8113c..c90116a74 100644 --- a/src/datajoint/autopopulate.py +++ b/src/datajoint/autopopulate.py @@ -5,7 +5,6 @@ import inspect import logging import multiprocessing as mp -import random import signal import traceback @@ -13,8 +12,7 @@ from tqdm import tqdm from .errors import DataJointError, LostConnectionError -from .expression import AndList, QueryExpression -from .hash import key_hash +from .expression import AndList # noinspection PyExceptionInherit,PyCallingNonCallable @@ -55,6 +53,7 @@ class AutoPopulate: _key_source = None _allow_insert = False + _jobs_table = None # Cached JobsTable instance @property def key_source(self): @@ -74,7 +73,7 @@ def _rename_attributes(table, props): ) if self._key_source is None: - parents = self.target.parents(primary=True, as_objects=True, foreign_key_info=True) + parents = self.parents(primary=True, as_objects=True, foreign_key_info=True) if not parents: raise DataJointError("A table must have dependencies from its primary key for auto-populate to work") self._key_source = _rename_attributes(*parents[0]) @@ -152,49 +151,20 @@ def make(self, key): yield @property - def target(self): + def jobs(self): """ - :return: table to be populated. - In the typical case, dj.AutoPopulate is mixed into a dj.Table class by - inheritance and the target is self. - """ - return self + Access the jobs table for this auto-populated table. - def _job_key(self, key): - """ - :param key: they key returned for the job from the key source - :return: the dict to use to generate the job reservation hash - This method allows subclasses to control the job reservation granularity. - """ - return key + The jobs table provides per-table job queue management with rich status + tracking (pending, reserved, success, error, ignore). - def _jobs_to_do(self, restrictions): - """ - :return: the query yielding the keys to be computed (derived from self.key_source) + :return: JobsTable instance for this table """ - if self.restriction: - raise DataJointError( - "Cannot call populate on a restricted table. Instead, pass conditions to populate() as arguments." - ) - todo = self.key_source + if self._jobs_table is None: + from .jobs import JobsTable - # key_source is a QueryExpression subclass -- trigger instantiation - if inspect.isclass(todo) and issubclass(todo, QueryExpression): - todo = todo() - - if not isinstance(todo, QueryExpression): - raise DataJointError("Invalid key_source value") - - try: - # check if target lacks any attributes from the primary key of key_source - raise DataJointError( - "The populate target lacks attribute %s " - "from the primary key of key_source" - % next(name for name in todo.heading.primary_key if name not in self.target.heading) - ) - except StopIteration: - pass - return (todo & AndList(restrictions)).proj() + self._jobs_table = JobsTable(self) + return self._jobs_table def populate( self, @@ -203,12 +173,12 @@ def populate( suppress_errors=False, return_exception_objects=False, reserve_jobs=False, - order="original", - limit=None, max_calls=None, display_progress=False, processes=1, make_kwargs=None, + priority=None, + refresh=True, ): """ ``table.populate()`` calls ``table.make(key)`` for every primary key in @@ -221,8 +191,6 @@ def populate( :param suppress_errors: if True, do not terminate execution. :param return_exception_objects: return error objects instead of just error messages :param reserve_jobs: if True, reserve jobs to populate in asynchronous fashion - :param order: "original"|"reverse"|"random" - the order of execution - :param limit: if not None, check at most this many keys :param max_calls: if not None, populate at most this many keys :param display_progress: if True, report progress_bar :param processes: number of processes to use. Set to None to use all cores @@ -230,6 +198,10 @@ def populate( to be passed down to each ``make()`` call. Computation arguments should be specified within the pipeline e.g. using a `dj.Lookup` table. :type make_kwargs: dict, optional + :param priority: Only process jobs at this priority or more urgent (lower values). + Only applies when reserve_jobs=True. + :param refresh: If True and no pending jobs are found, refresh the jobs queue + before giving up. Only applies when reserve_jobs=True. :return: a dict with two keys "success_count": the count of successful ``make()`` calls in this ``populate()`` call "error_list": the error list that is filled if `suppress_errors` is True @@ -237,10 +209,10 @@ def populate( if self.connection.in_transaction: raise DataJointError("Populate cannot be called during a transaction.") - valid_order = ["original", "reverse", "random"] - if order not in valid_order: - raise DataJointError("The order argument must be one of %s" % str(valid_order)) - jobs = self.connection.schemas[self.target.database].jobs if reserve_jobs else None + if self.restriction: + raise DataJointError( + "Cannot call populate on a restricted table. " "Instead, pass conditions to populate() as arguments." + ) if reserve_jobs: # Define a signal handler for SIGTERM @@ -250,29 +222,25 @@ def handler(signum, frame): old_handler = signal.signal(signal.SIGTERM, handler) - if keys is None: - keys = (self._jobs_to_do(restrictions) - self.target).fetch("KEY", limit=limit) + error_list = [] + success_list = [] - # exclude "error", "ignore" or "reserved" jobs if reserve_jobs: - exclude_key_hashes = ( - jobs & {"table_name": self.target.table_name} & 'status in ("error", "ignore", "reserved")' - ).fetch("key_hash") - keys = [key for key in keys if key_hash(key) not in exclude_key_hashes] - - if order == "reverse": - keys.reverse() - elif order == "random": - random.shuffle(keys) + # Use jobs table for coordinated processing + keys = self.jobs.fetch_pending(limit=max_calls, priority=priority) + if not keys and refresh: + logger.debug("No pending jobs found, refreshing jobs queue") + self.jobs.refresh(*restrictions) + keys = self.jobs.fetch_pending(limit=max_calls, priority=priority) + else: + # Without job reservations: compute keys directly from key_source + if keys is None: + todo = (self.key_source & AndList(restrictions)).proj() + keys = (todo - self).fetch("KEY", limit=max_calls) logger.debug("Found %d keys to populate" % len(keys)) - - keys = keys[:max_calls] nkeys = len(keys) - error_list = [] - success_list = [] - if nkeys: processes = min(_ for _ in (processes, nkeys, mp.cpu_count()) if _) @@ -282,6 +250,8 @@ def handler(signum, frame): make_kwargs=make_kwargs, ) + jobs = self.jobs if reserve_jobs else None + if processes == 1: for key in tqdm(keys, desc=self.__class__.__name__) if display_progress else keys: status = self._populate1(key, jobs, **populate_kwargs) @@ -322,46 +292,49 @@ def handler(signum, frame): def _populate1(self, key, jobs, suppress_errors, return_exception_objects, make_kwargs=None): """ populates table for one source key, calling self.make inside a transaction. - :param jobs: the jobs table or None if not reserve_jobs + :param jobs: the jobs table (JobsTable) or None if not reserve_jobs :param key: dict specifying job to populate :param suppress_errors: bool if errors should be suppressed and returned :param return_exception_objects: if True, errors must be returned as objects :return: (key, error) when suppress_errors=True, True if successfully invoke one `make()` call, otherwise False """ - # use the legacy `_make_tuples` callback. - make = self._make_tuples if hasattr(self, "_make_tuples") else self.make + import time - if jobs is not None and not jobs.reserve(self.target.table_name, self._job_key(key)): - return False + start_time = time.time() - # if make is a generator, it transaction can be delayed until the final stage - is_generator = inspect.isgeneratorfunction(make) + # Reserve the job (per-key, before make) + if jobs is not None: + jobs.reserve(key) + + # if make is a generator, transaction can be delayed until the final stage + is_generator = inspect.isgeneratorfunction(self.make) if not is_generator: self.connection.start_transaction() - if key in self.target: # already populated + if key in self: # already populated if not is_generator: self.connection.cancel_transaction() if jobs is not None: - jobs.complete(self.target.table_name, self._job_key(key)) + # Job already done - mark complete or delete + jobs.complete(key, duration=0) return False - logger.debug(f"Making {key} -> {self.target.full_table_name}") + logger.debug(f"Making {key} -> {self.full_table_name}") self.__class__._allow_insert = True try: if not is_generator: - make(dict(key), **(make_kwargs or {})) + self.make(dict(key), **(make_kwargs or {})) else: # tripartite make - transaction is delayed until the final stage - gen = make(dict(key), **(make_kwargs or {})) + gen = self.make(dict(key), **(make_kwargs or {})) fetched_data = next(gen) fetch_hash = deepdiff.DeepHash(fetched_data, ignore_iterable_order=False)[fetched_data] computed_result = next(gen) # perform the computation # fetch and insert inside a transaction self.connection.start_transaction() - gen = make(dict(key), **(make_kwargs or {})) # restart make + gen = self.make(dict(key), **(make_kwargs or {})) # restart make fetched_data = next(gen) if ( fetch_hash != deepdiff.DeepHash(fetched_data, ignore_iterable_order=False)[fetched_data] @@ -378,15 +351,25 @@ def _populate1(self, key, jobs, suppress_errors, return_exception_objects, make_ exception=error.__class__.__name__, msg=": " + str(error) if str(error) else "", ) - logger.debug(f"Error making {key} -> {self.target.full_table_name} - {error_message}") + logger.debug(f"Error making {key} -> {self.full_table_name} - {error_message}") + + # Only log errors from inside make() - not collision errors if jobs is not None: - # show error name and error message (if any) - jobs.error( - self.target.table_name, - self._job_key(key), - error_message=error_message, - error_stack=traceback.format_exc(), - ) + from .errors import DuplicateError + + if isinstance(error, DuplicateError): + # Collision error - job reverts to pending or gets deleted + # This is not a real error, just coordination artifact + logger.debug(f"Duplicate key collision for {key}, reverting job") + # Delete the reservation, letting the job be picked up again or cleaned + (jobs & key).delete_quick() + else: + # Real error inside make() - log it + jobs.error( + key, + error_message=error_message, + error_stack=traceback.format_exc(), + ) if not suppress_errors or isinstance(error, SystemExit): raise else: @@ -394,9 +377,10 @@ def _populate1(self, key, jobs, suppress_errors, return_exception_objects, make_ return key, error if return_exception_objects else error_message else: self.connection.commit_transaction() - logger.debug(f"Success making {key} -> {self.target.full_table_name}") + duration = time.time() - start_time + logger.debug(f"Success making {key} -> {self.full_table_name}") if jobs is not None: - jobs.complete(self.target.table_name, self._job_key(key)) + jobs.complete(key, duration=duration) return True finally: self.__class__._allow_insert = False @@ -406,9 +390,9 @@ def progress(self, *restrictions, display=False): Report the progress of populating the table. :return: (remaining, total) -- numbers of tuples to be populated """ - todo = self._jobs_to_do(restrictions) + todo = (self.key_source & AndList(restrictions)).proj() total = len(todo) - remaining = len(todo - self.target) + remaining = len(todo - self) if display: logger.info( "%-20s" % self.__class__.__name__ diff --git a/src/datajoint/declare.py b/src/datajoint/declare.py index a1613d7d2..397da108b 100644 --- a/src/datajoint/declare.py +++ b/src/datajoint/declare.py @@ -489,8 +489,8 @@ def substitute_special_type(match, category, foreign_key_sql, context): "ON UPDATE RESTRICT ON DELETE RESTRICT".format(external_table_root=EXTERNAL_TABLE_ROOT, **match) ) elif category == "ADAPTED": - adapter = get_adapter(context, match["type"]) - match["type"] = adapter.attribute_type + attr_type = get_adapter(context, match["type"]) + match["type"] = attr_type.dtype category = match_type(match["type"]) if category in SPECIAL_TYPES: # recursive redefinition from user-defined datatypes. diff --git a/src/datajoint/fetch.py b/src/datajoint/fetch.py index 3dab1f38b..e1b655fc0 100644 --- a/src/datajoint/fetch.py +++ b/src/datajoint/fetch.py @@ -10,7 +10,7 @@ from datajoint.condition import Top -from . import blob, hash +from . import hash from .errors import DataJointError from .objectref import ObjectRef from .settings import config @@ -66,8 +66,9 @@ def _get(connection, attr, data, squeeze, download_path): extern = connection.schemas[attr.database].external[attr.store] if attr.is_external else None - # apply attribute adapter if present - adapt = attr.adapter.get if attr.adapter else lambda x: x + # apply custom attribute type decoder if present + def adapt(x): + return attr.adapter.decode(x, key=None) if attr.adapter else x if attr.is_filepath: return adapt(extern.download_filepath(uuid.UUID(bytes=data))[0]) @@ -100,18 +101,17 @@ def _get(connection, attr, data, squeeze, download_path): safe_write(local_filepath, data.split(b"\0", 1)[1]) return adapt(str(local_filepath)) # download file from remote store - return adapt( - uuid.UUID(bytes=data) - if attr.uuid - else ( - blob.unpack( - extern.get(uuid.UUID(bytes=data)) if attr.is_external else data, - squeeze=squeeze, - ) - if attr.is_blob - else data - ) - ) + if attr.uuid: + return adapt(uuid.UUID(bytes=data)) + elif attr.is_blob: + blob_data = extern.get(uuid.UUID(bytes=data)) if attr.is_external else data + # Adapters (like ) handle deserialization in decode() + # Without adapter, blob columns return raw bytes (no deserialization) + if attr.adapter: + return attr.adapter.decode(blob_data, key=None) + return blob_data # raw bytes + else: + return adapt(data) class Fetch: diff --git a/src/datajoint/heading.py b/src/datajoint/heading.py index f4bd57a79..cc8034cd7 100644 --- a/src/datajoint/heading.py +++ b/src/datajoint/heading.py @@ -5,7 +5,8 @@ import numpy as np -from .attribute_adapter import AttributeAdapter, get_adapter +from .attribute_adapter import get_adapter +from .attribute_type import AttributeType from .declare import ( EXTERNAL_TYPES, NATIVE_TYPES, @@ -15,6 +16,37 @@ ) from .errors import FILEPATH_FEATURE_SWITCH, DataJointError, _support_filepath_types + +class _MissingType(AttributeType): + """Placeholder for missing/unregistered attribute types. Raises error on use.""" + + def __init__(self, name: str): + self._name = name + + @property + def type_name(self) -> str: + return self._name + + @property + def dtype(self) -> str: + raise DataJointError( + f"Attribute type <{self._name}> is not registered. " + "Register it with @dj.register_type or include it in the schema context." + ) + + def encode(self, value, *, key=None): + raise DataJointError( + f"Attribute type <{self._name}> is not registered. " + "Register it with @dj.register_type or include it in the schema context." + ) + + def decode(self, stored, *, key=None): + raise DataJointError( + f"Attribute type <{self._name}> is not registered. " + "Register it with @dj.register_type or include it in the schema context." + ) + + logger = logging.getLogger(__name__.split(".")[0]) default_attribute_properties = dict( # these default values are set in computed attributes @@ -289,7 +321,7 @@ def _init_from_database(self): if special: special = special.groupdict() attr.update(special) - # process adapted attribute types + # process custom attribute types (adapted types) if special and TYPE_PATTERN["ADAPTED"].match(attr["type"]): assert context is not None, "Declaration context is not set" adapter_name = special["type"] @@ -297,15 +329,11 @@ def _init_from_database(self): attr.update(adapter=get_adapter(context, adapter_name)) except DataJointError: # if no adapter, then delay the error until the first invocation - attr.update(adapter=AttributeAdapter()) + attr.update(adapter=_MissingType(adapter_name)) else: - attr.update(type=attr["adapter"].attribute_type) + attr.update(type=attr["adapter"].dtype) if not any(r.match(attr["type"]) for r in TYPE_PATTERN.values()): - raise DataJointError( - "Invalid attribute type '{type}' in adapter object <{adapter_name}>.".format( - adapter_name=adapter_name, **attr - ) - ) + raise DataJointError(f"Invalid dtype '{attr['type']}' in attribute type <{adapter_name}>.") special = not any(TYPE_PATTERN[c].match(attr["type"]) for c in NATIVE_TYPES) if special: diff --git a/src/datajoint/jobs.py b/src/datajoint/jobs.py index ff6440495..7dff66333 100644 --- a/src/datajoint/jobs.py +++ b/src/datajoint/jobs.py @@ -1,154 +1,502 @@ +""" +Autopopulate 2.0 Jobs System + +This module implements per-table job tables for auto-populated tables. +Each dj.Imported or dj.Computed table gets its own hidden jobs table +with FK-derived primary keys and rich status tracking. +""" + +import logging import os import platform +from datetime import datetime +from typing import TYPE_CHECKING -from .errors import DuplicateError -from .hash import key_hash +from .errors import DataJointError, DuplicateError +from .expression import QueryExpression from .heading import Heading from .settings import config from .table import Table +if TYPE_CHECKING: + from .autopopulate import AutoPopulate + +logger = logging.getLogger(__name__.split(".")[0]) + ERROR_MESSAGE_LENGTH = 2047 TRUNCATION_APPENDIX = "...truncated" +# Default configuration values +DEFAULT_STALE_TIMEOUT = 3600 # 1 hour +DEFAULT_PRIORITY = 5 +DEFAULT_KEEP_COMPLETED = False -class JobTable(Table): + +class JobsTable(Table): """ - A base table with no definition. Allows reserving jobs + Per-table job queue for auto-populated tables. + + Each dj.Imported or dj.Computed table has an associated hidden jobs table + with the naming convention ~__jobs. + + The jobs table primary key includes only those attributes derived from + foreign keys in the target table's primary key. Additional primary key + attributes (if any) are excluded. + + Status values: + - pending: Job is queued and ready to be processed + - reserved: Job is currently being processed by a worker + - success: Job completed successfully + - error: Job failed with an error + - ignore: Job should be skipped (manually set) """ - def __init__(self, conn, database): - self.database = database - self._connection = conn - self._heading = Heading(table_info=dict(conn=conn, database=database, table_name=self.table_name, context=None)) + def __init__(self, target: "AutoPopulate"): + """ + Initialize a JobsTable for the given auto-populated table. + + Args: + target: The auto-populated table (dj.Imported or dj.Computed) + """ + self._target = target + self._connection = target.connection + self.database = target.database + self._user = self.connection.get_user() + + # Derive the jobs table name from the target table + # e.g., __filtered_image -> _filtered_image__jobs + target_table_name = target.table_name + if target_table_name.startswith("__"): + # Computed table: __foo -> _foo__jobs + self._table_name = f"~{target_table_name[2:]}__jobs" + elif target_table_name.startswith("_"): + # Imported table: _foo -> _foo__jobs + self._table_name = f"~{target_table_name[1:]}__jobs" + else: + # Manual/Lookup (shouldn't happen for auto-populated) + self._table_name = f"~{target_table_name}__jobs" + + # Build the definition dynamically based on target's FK-derived primary key + self._definition = self._build_definition() + + # Initialize heading + self._heading = Heading( + table_info=dict( + conn=self._connection, + database=self.database, + table_name=self.table_name, + context=None, + ) + ) self._support = [self.full_table_name] - self._definition = """ # job reservation table for `{database}` - table_name :varchar(255) # className of the table - key_hash :char(32) # key hash - --- - status :enum('reserved','error','ignore') # if tuple is missing, the job is available - key=null :blob # structure containing the key - error_message="" :varchar({error_message_length}) # error message returned if failed - error_stack=null :mediumblob # error stack if failed - user="" :varchar(255) # database user - host="" :varchar(255) # system hostname - pid=0 :int unsigned # system process id - connection_id = 0 : bigint unsigned # connection_id() - timestamp=CURRENT_TIMESTAMP :timestamp # automatic timestamp - """.format(database=database, error_message_length=ERROR_MESSAGE_LENGTH) + def _get_fk_derived_primary_key(self) -> list[tuple[str, str]]: + """ + Get the FK-derived primary key attributes from the target table. + + Returns: + List of (attribute_name, attribute_type) tuples for FK-derived PK attributes. + """ + # Get parent tables that contribute to the primary key + parents = self._target.parents(primary=True, as_objects=True, foreign_key_info=True) + + # Collect all FK-derived primary key attributes + fk_pk_attrs = set() + for parent_table, props in parents: + # attr_map maps child attr -> parent attr + for child_attr in props["attr_map"].keys(): + fk_pk_attrs.add(child_attr) + + # Get attribute definitions from target table's heading + pk_definitions = [] + for attr_name in self._target.primary_key: + if attr_name in fk_pk_attrs: + attr = self._target.heading.attributes[attr_name] + # Build attribute definition string + attr_def = f"{attr_name} : {attr.type}" + pk_definitions.append((attr_name, attr_def)) + + return pk_definitions + + def _build_definition(self) -> str: + """ + Build the table definition for the jobs table. + + Returns: + DataJoint table definition string. + """ + # Get FK-derived primary key attributes + pk_attrs = self._get_fk_derived_primary_key() + + if not pk_attrs: + raise DataJointError( + f"Cannot create jobs table for {self._target.full_table_name}: " + "no foreign-key-derived primary key attributes found." + ) + + # Build primary key section + pk_section = "\n".join(attr_def for _, attr_def in pk_attrs) + + definition = f"""# Job queue for {self._target.class_name} +{pk_section} +--- +status : enum('pending', 'reserved', 'success', 'error', 'ignore') +priority : int # Lower = more urgent (0 = highest priority) +created_time : datetime(6) # When job was added to queue +scheduled_time : datetime(6) # Process on or after this time +reserved_time=null : datetime(6) # When job was reserved +completed_time=null : datetime(6) # When job completed +duration=null : float # Execution duration in seconds +error_message="" : varchar({ERROR_MESSAGE_LENGTH}) # Error message if failed +error_stack=null : # Full error traceback +user="" : varchar(255) # Database user who reserved/completed job +host="" : varchar(255) # Hostname of worker +pid=0 : int unsigned # Process ID of worker +connection_id=0 : bigint unsigned # MySQL connection ID +version="" : varchar(255) # Code version +""" + return definition + + @property + def definition(self) -> str: + return self._definition + + @property + def table_name(self) -> str: + return self._table_name + + @property + def target(self) -> "AutoPopulate": + """The auto-populated table this jobs table is associated with.""" + return self._target + + def _ensure_declared(self) -> None: + """Ensure the jobs table is declared in the database.""" if not self.is_declared: self.declare() - self._user = self.connection.get_user() + + # --- Status filter properties --- @property - def definition(self): - return self._definition + def pending(self) -> QueryExpression: + """Return query for pending jobs.""" + self._ensure_declared() + return self & 'status="pending"' + + @property + def reserved(self) -> QueryExpression: + """Return query for reserved jobs.""" + self._ensure_declared() + return self & 'status="reserved"' + + @property + def errors(self) -> QueryExpression: + """Return query for error jobs.""" + self._ensure_declared() + return self & 'status="error"' + + @property + def ignored(self) -> QueryExpression: + """Return query for ignored jobs.""" + self._ensure_declared() + return self & 'status="ignore"' @property - def table_name(self): - return "~jobs" + def completed(self) -> QueryExpression: + """Return query for completed (success) jobs.""" + self._ensure_declared() + return self & 'status="success"' + + # --- Core methods --- - def delete(self): - """bypass interactive prompts and dependencies""" + def delete(self) -> None: + """Delete jobs without confirmation (inherits from delete_quick).""" self.delete_quick() - def drop(self): - """bypass interactive prompts and dependencies""" + def drop(self) -> None: + """Drop the jobs table without confirmation.""" self.drop_quick() - def reserve(self, table_name, key): + def refresh( + self, + *restrictions, + delay: float = 0, + priority: int = None, + stale_timeout: float = None, + ) -> dict: """ - Reserve a job for computation. When a job is reserved, the job table contains an entry for the - job key, identified by its hash. When jobs are completed, the entry is removed. + Refresh the jobs queue: add new jobs and remove stale ones. - :param table_name: `database`.`table_name` - :param key: the dict of the job's primary key - :return: True if reserved job successfully. False = the jobs is already taken + Operations performed: + 1. Add new jobs: (key_source & restrictions) - target - jobs → insert as 'pending' + 2. Remove stale jobs: pending jobs older than stale_timeout whose keys + are no longer in key_source + + Args: + restrictions: Conditions to filter key_source + delay: Seconds from now until jobs become available for processing. + Default: 0 (jobs are immediately available). + Uses database server time to avoid clock sync issues. + priority: Priority for new jobs (lower = more urgent). Default from config. + stale_timeout: Seconds after which pending jobs are checked for staleness. + Default from config. + + Returns: + {'added': int, 'removed': int} - counts of jobs added and stale jobs removed """ - job = dict( - table_name=table_name, - key_hash=key_hash(key), - status="reserved", - host=platform.node(), - pid=os.getpid(), - connection_id=self.connection.connection_id, - key=key, - user=self._user, - ) - try: - with config.override(enable_python_native_blobs=True): - self.insert1(job, ignore_extra_fields=True) - except DuplicateError: - return False - return True + self._ensure_declared() + + if priority is None: + priority = config.jobs.default_priority + if stale_timeout is None: + stale_timeout = config.jobs.stale_timeout + + # Get FK-derived primary key attribute names + pk_attrs = [name for name, _ in self._get_fk_derived_primary_key()] - def ignore(self, table_name, key): + # Step 1: Find new keys to add + # (key_source & restrictions) - target - jobs + key_source = self._target.key_source + if restrictions: + from .expression import AndList + + key_source = key_source & AndList(restrictions) + + # Project to FK-derived attributes only + key_source_proj = key_source.proj(*pk_attrs) + target_proj = self._target.proj(*pk_attrs) + existing_jobs = self.proj() # jobs table PK is the FK-derived attrs + + # Keys that need jobs: in key_source, not in target, not already in jobs + new_keys = (key_source_proj - target_proj - existing_jobs).fetch("KEY") + + # Insert new jobs + added = 0 + for key in new_keys: + try: + self._insert_job_with_delay(key, priority, delay) + added += 1 + except DuplicateError: + # Job was added by another process + pass + + # Step 2: Remove stale pending jobs + # Find pending jobs older than stale_timeout whose keys are not in key_source + removed = 0 + if stale_timeout > 0: + stale_condition = f'status="pending" AND ' f"created_time < NOW() - INTERVAL {stale_timeout} SECOND" + stale_jobs = (self & stale_condition).proj() + + # Check which stale jobs are no longer in key_source + orphaned_keys = (stale_jobs - key_source_proj).fetch("KEY") + for key in orphaned_keys: + (self & key).delete_quick() + removed += 1 + + return {"added": added, "removed": removed} + + def _insert_job_with_delay(self, key: dict, priority: int, delay: float) -> None: """ - Set a job to be ignored for computation. When a job is ignored, the job table contains an entry for the - job key, identified by its hash, with status "ignore". + Insert a new job with scheduled_time set using database server time. Args: - table_name: - Table name (str) - `database`.`table_name` - key: - The dict of the job's primary key + key: Primary key dict for the job + priority: Job priority (lower = more urgent) + delay: Seconds from now until job becomes available + """ + # Build column names and values + pk_attrs = [name for name, _ in self._get_fk_derived_primary_key()] + columns = pk_attrs + ["status", "priority", "created_time", "scheduled_time", "user", "host", "pid", "connection_id"] - Returns: - True if ignore job successfully. False = the jobs is already taken - """ - job = dict( - table_name=table_name, - key_hash=key_hash(key), - status="ignore", - host=platform.node(), - pid=os.getpid(), - connection_id=self.connection.connection_id, - key=key, - user=self._user, - ) - try: - with config.override(enable_python_native_blobs=True): - self.insert1(job, ignore_extra_fields=True) - except DuplicateError: - return False - return True + # Build values + pk_values = [f"'{key[attr]}'" if isinstance(key[attr], str) else str(key[attr]) for attr in pk_attrs] + other_values = [ + "'pending'", + str(priority), + "NOW(6)", # created_time + f"NOW(6) + INTERVAL {delay} SECOND" if delay > 0 else "NOW(6)", # scheduled_time + f"'{self._user}'", + f"'{platform.node()}'", + str(os.getpid()), + str(self.connection.connection_id), + ] - def complete(self, table_name, key): + sql = f""" + INSERT INTO {self.full_table_name} + ({', '.join(f'`{c}`' for c in columns)}) + VALUES ({', '.join(pk_values + other_values)}) """ - Log a completed job. When a job is completed, its reservation entry is deleted. + self.connection.query(sql) - :param table_name: `database`.`table_name` - :param key: the dict of the job's primary key + def reserve(self, key: dict) -> None: """ - job_key = dict(table_name=table_name, key_hash=key_hash(key)) - (self & job_key).delete_quick() + Reserve a job for processing. + + Updates the job record to 'reserved' status. The caller (populate) is + responsible for verifying the job is pending before calling this method. - def error(self, table_name, key, error_message, error_stack=None): + Args: + key: Primary key dict for the job """ - Log an error message. The job reservation is replaced with an error entry. - if an error occurs, leave an entry describing the problem + self._ensure_declared() + + pk_attrs = [name for name, _ in self._get_fk_derived_primary_key()] + job_key = {attr: key[attr] for attr in pk_attrs if attr in key} - :param table_name: `database`.`table_name` - :param key: the dict of the job's primary key - :param error_message: string error message - :param error_stack: stack trace + update_row = { + **job_key, + "status": "reserved", + "reserved_time": datetime.now(), + "user": self._user, + "host": platform.node(), + "pid": os.getpid(), + "connection_id": self.connection.connection_id, + } + self.update1(update_row) + + def complete(self, key: dict, duration: float = None, keep: bool = None) -> None: + """ + Mark a job as successfully completed. + + Args: + key: Primary key dict for the job + duration: Execution duration in seconds + keep: If True, mark as 'success'. If False, delete the job entry. + Default from config (jobs.keep_completed). """ + self._ensure_declared() + + if keep is None: + keep = config.jobs.keep_completed + + pk_attrs = [name for name, _ in self._get_fk_derived_primary_key()] + job_key = {attr: key[attr] for attr in pk_attrs if attr in key} + + if keep: + # Update to success status + update_row = { + **job_key, + "status": "success", + "completed_time": datetime.now(), + } + if duration is not None: + update_row["duration"] = duration + self.update1(update_row) + else: + # Delete the job entry + (self & job_key).delete_quick() + + def error(self, key: dict, error_message: str, error_stack: str = None) -> None: + """ + Mark a job as failed with error details. + + Args: + key: Primary key dict for the job + error_message: Error message string + error_stack: Full stack trace + """ + self._ensure_declared() + + # Truncate error message if necessary if len(error_message) > ERROR_MESSAGE_LENGTH: error_message = error_message[: ERROR_MESSAGE_LENGTH - len(TRUNCATION_APPENDIX)] + TRUNCATION_APPENDIX - with config.override(enable_python_native_blobs=True): - self.insert1( - dict( - table_name=table_name, - key_hash=key_hash(key), - status="error", - host=platform.node(), - pid=os.getpid(), - connection_id=self.connection.connection_id, - user=self._user, - key=key, - error_message=error_message, - error_stack=error_stack, - ), - replace=True, - ignore_extra_fields=True, - ) + + pk_attrs = [name for name, _ in self._get_fk_derived_primary_key()] + job_key = {attr: key[attr] for attr in pk_attrs if attr in key} + + # Build update dict with all required fields + update_row = { + **job_key, + "status": "error", + "completed_time": datetime.now(), + "error_message": error_message, + } + if error_stack is not None: + update_row["error_stack"] = error_stack + + self.update1(update_row) + + def ignore(self, key: dict) -> None: + """ + Mark a key to be ignored (skipped during populate). + + Only inserts new records. Existing job entries cannot be converted to + ignore status - they must be cleared first. + + Args: + key: Primary key dict for the job + """ + self._ensure_declared() + + pk_attrs = [name for name, _ in self._get_fk_derived_primary_key()] + job_key = {attr: key[attr] for attr in pk_attrs if attr in key} + + try: + self._insert_job_with_status(job_key, "ignore") + except DuplicateError: + pass # Already tracked + + def _insert_job_with_status(self, key: dict, status: str) -> None: + """Insert a new job with the given status.""" + now = datetime.now() + row = { + **key, + "status": status, + "priority": DEFAULT_PRIORITY, + "created_time": now, + "scheduled_time": now, + "user": self._user, + "host": platform.node(), + "pid": os.getpid(), + "connection_id": self.connection.connection_id, + } + self.insert1(row) + + def progress(self) -> dict: + """ + Report detailed progress of job processing. + + Returns: + Dict with counts for each status and total. + """ + self._ensure_declared() + + result = { + "pending": len(self.pending), + "reserved": len(self.reserved), + "success": len(self.completed), + "error": len(self.errors), + "ignore": len(self.ignored), + } + result["total"] = sum(result.values()) + return result + + def fetch_pending( + self, + limit: int = None, + priority: int = None, + ) -> list[dict]: + """ + Fetch pending jobs ordered by priority and scheduled time. + + Args: + limit: Maximum number of jobs to fetch + priority: Only fetch jobs at this priority or more urgent (lower values) + + Returns: + List of job key dicts + """ + self._ensure_declared() + + # Build query for non-stale pending jobs + query = self & 'status="pending" AND scheduled_time <= NOW(6)' + + if priority is not None: + query = query & f"priority <= {priority}" + + # Fetch with ordering + return query.fetch( + "KEY", + order_by=["priority ASC", "scheduled_time ASC"], + limit=limit, + ) diff --git a/src/datajoint/migrate.py b/src/datajoint/migrate.py new file mode 100644 index 000000000..696ca380e --- /dev/null +++ b/src/datajoint/migrate.py @@ -0,0 +1,250 @@ +""" +Migration utilities for DataJoint schema updates. + +This module provides tools for migrating existing schemas to use the new +AttributeType system, particularly for upgrading blob columns to use +explicit `` type declarations. +""" + +from __future__ import annotations + +import logging +import re +from typing import TYPE_CHECKING + +from .errors import DataJointError + +if TYPE_CHECKING: + from .schemas import Schema + +logger = logging.getLogger(__name__.split(".")[0]) + +# Pattern to detect blob types +BLOB_TYPES = re.compile(r"^(tiny|small|medium|long|)blob$", re.I) + + +def analyze_blob_columns(schema: Schema) -> list[dict]: + """ + Analyze a schema to find blob columns that could be migrated to . + + This function identifies blob columns that: + 1. Have a MySQL blob type (tinyblob, blob, mediumblob, longblob) + 2. Do NOT already have an adapter/type specified in their comment + + All blob size variants are included in the analysis. + + Args: + schema: The DataJoint schema to analyze. + + Returns: + List of dicts with keys: + - table_name: Full table name (database.table) + - column_name: Name of the blob column + - column_type: MySQL column type (tinyblob, blob, mediumblob, longblob) + - current_comment: Current column comment + - needs_migration: True if column should be migrated + + Example: + >>> import datajoint as dj + >>> schema = dj.schema('my_database') + >>> columns = dj.migrate.analyze_blob_columns(schema) + >>> for col in columns: + ... if col['needs_migration']: + ... print(f"{col['table_name']}.{col['column_name']} ({col['column_type']})") + """ + results = [] + + connection = schema.connection + + # Get all tables in the schema + tables_query = """ + SELECT TABLE_NAME + FROM information_schema.TABLES + WHERE TABLE_SCHEMA = %s + AND TABLE_TYPE = 'BASE TABLE' + AND TABLE_NAME NOT LIKE '~%%' + """ + + tables = connection.query(tables_query, args=(schema.database,)).fetchall() + + for (table_name,) in tables: + # Get column information for each table + columns_query = """ + SELECT COLUMN_NAME, COLUMN_TYPE, COLUMN_COMMENT + FROM information_schema.COLUMNS + WHERE TABLE_SCHEMA = %s + AND TABLE_NAME = %s + AND DATA_TYPE IN ('tinyblob', 'blob', 'mediumblob', 'longblob') + """ + + columns = connection.query(columns_query, args=(schema.database, table_name)).fetchall() + + for column_name, column_type, comment in columns: + # Check if comment already has an adapter type (starts with :type:) + has_adapter = comment and comment.startswith(":") + + results.append( + { + "table_name": f"{schema.database}.{table_name}", + "column_name": column_name, + "column_type": column_type, + "current_comment": comment or "", + "needs_migration": not has_adapter, + } + ) + + return results + + +def generate_migration_sql( + schema: Schema, + target_type: str = "djblob", + dry_run: bool = True, +) -> list[str]: + """ + Generate SQL statements to migrate blob columns to use . + + This generates ALTER TABLE statements that update column comments to + include the `::` prefix, marking them as using explicit + DataJoint blob serialization. + + Args: + schema: The DataJoint schema to migrate. + target_type: The type name to migrate to (default: "djblob"). + dry_run: If True, only return SQL without executing. + + Returns: + List of SQL ALTER TABLE statements. + + Example: + >>> sql_statements = dj.migrate.generate_migration_sql(schema) + >>> for sql in sql_statements: + ... print(sql) + + Note: + This is a metadata-only migration. The actual blob data format + remains unchanged - only the column comments are updated to + indicate explicit type handling. + """ + columns = analyze_blob_columns(schema) + sql_statements = [] + + for col in columns: + if not col["needs_migration"]: + continue + + # Build new comment with type prefix + old_comment = col["current_comment"] + new_comment = f":<{target_type}>:{old_comment}" + + # Escape special characters for SQL + new_comment_escaped = new_comment.replace("\\", "\\\\").replace("'", "\\'") + + # Parse table name + db_name, table_name = col["table_name"].split(".") + + # Generate ALTER TABLE statement + sql = ( + f"ALTER TABLE `{db_name}`.`{table_name}` " + f"MODIFY COLUMN `{col['column_name']}` {col['column_type']} " + f"COMMENT '{new_comment_escaped}'" + ) + sql_statements.append(sql) + + return sql_statements + + +def migrate_blob_columns( + schema: Schema, + target_type: str = "djblob", + dry_run: bool = True, +) -> dict: + """ + Migrate blob columns in a schema to use explicit type. + + This updates column comments in the database to include the type + declaration. The data format remains unchanged. + + Args: + schema: The DataJoint schema to migrate. + target_type: The type name to migrate to (default: "djblob"). + dry_run: If True, only preview changes without applying. + + Returns: + Dict with keys: + - analyzed: Number of blob columns analyzed + - needs_migration: Number of columns that need migration + - migrated: Number of columns migrated (0 if dry_run) + - sql_statements: List of SQL statements (executed or to be executed) + + Example: + >>> # Preview migration + >>> result = dj.migrate.migrate_blob_columns(schema, dry_run=True) + >>> print(f"Would migrate {result['needs_migration']} columns") + + >>> # Apply migration + >>> result = dj.migrate.migrate_blob_columns(schema, dry_run=False) + >>> print(f"Migrated {result['migrated']} columns") + + Warning: + After migration, table definitions should be updated to use + `` instead of `longblob` for consistency. The migration + only updates database metadata; source code changes are manual. + """ + columns = analyze_blob_columns(schema) + sql_statements = generate_migration_sql(schema, target_type=target_type) + + result = { + "analyzed": len(columns), + "needs_migration": sum(1 for c in columns if c["needs_migration"]), + "migrated": 0, + "sql_statements": sql_statements, + } + + if dry_run: + logger.info(f"Dry run: would migrate {result['needs_migration']} columns") + for sql in sql_statements: + logger.info(f" {sql}") + return result + + # Execute migrations + connection = schema.connection + for sql in sql_statements: + try: + connection.query(sql) + result["migrated"] += 1 + logger.info(f"Executed: {sql}") + except Exception as e: + logger.error(f"Failed to execute: {sql}\nError: {e}") + raise DataJointError(f"Migration failed: {e}") from e + + logger.info(f"Successfully migrated {result['migrated']} columns") + return result + + +def check_migration_status(schema: Schema) -> dict: + """ + Check the migration status of blob columns in a schema. + + Args: + schema: The DataJoint schema to check. + + Returns: + Dict with keys: + - total_blob_columns: Total number of blob columns + - migrated: Number of columns with explicit type + - pending: Number of columns using implicit serialization + - columns: List of column details + + Example: + >>> status = dj.migrate.check_migration_status(schema) + >>> print(f"Migration progress: {status['migrated']}/{status['total_blob_columns']}") + """ + columns = analyze_blob_columns(schema) + + return { + "total_blob_columns": len(columns), + "migrated": sum(1 for c in columns if not c["needs_migration"]), + "pending": sum(1 for c in columns if c["needs_migration"]), + "columns": columns, + } diff --git a/src/datajoint/schemas.py b/src/datajoint/schemas.py index e9b83efff..9df3ba34d 100644 --- a/src/datajoint/schemas.py +++ b/src/datajoint/schemas.py @@ -10,7 +10,6 @@ from .errors import AccessError, DataJointError from .external import ExternalMapping from .heading import Heading -from .jobs import JobTable from .settings import config from .table import FreeTable, Log, lookup_class_name from .user_tables import Computed, Imported, Lookup, Manual, Part, _get_tier @@ -70,7 +69,7 @@ def __init__( self.context = context self.create_schema = create_schema self.create_tables = create_tables - self._jobs = None + self._auto_populated_tables = [] # Track auto-populated table classes self.external = ExternalMapping(self) self.add_objects = add_objects self.declare_list = [] @@ -227,6 +226,11 @@ def _decorate_table(self, table_class, context, assert_declared=False): else: instance.insert(contents, skip_duplicates=True) + # Track auto-populated tables for schema.jobs + if isinstance(instance, (Imported, Computed)) and not isinstance(instance, Part): + if table_class not in self._auto_populated_tables: + self._auto_populated_tables.append(table_class) + @property def log(self): self._assert_exists() @@ -338,14 +342,15 @@ def exists(self): @property def jobs(self): """ - schema.jobs provides a view of the job reservation table for the schema + Access job tables for all auto-populated tables in the schema. + + Returns a list of JobsTable objects, one for each Imported or Computed + table in the schema. - :return: jobs table + :return: list of JobsTable objects """ self._assert_exists() - if self._jobs is None: - self._jobs = JobTable(self.connection, self.database) - return self._jobs + return [table_class().jobs for table_class in self._auto_populated_tables] @property def code(self): diff --git a/src/datajoint/settings.py b/src/datajoint/settings.py index a27f3a004..d83d11efd 100644 --- a/src/datajoint/settings.py +++ b/src/datajoint/settings.py @@ -188,6 +188,22 @@ class ExternalSettings(BaseSettings): aws_secret_access_key: SecretStr | None = Field(default=None, validation_alias="DJ_AWS_SECRET_ACCESS_KEY") +class JobsSettings(BaseSettings): + """Job queue settings for auto-populated tables.""" + + model_config = SettingsConfigDict( + env_prefix="DJ_JOBS_", + case_sensitive=False, + extra="forbid", + validate_assignment=True, + ) + + auto_refresh: bool = Field(default=True, description="Auto-refresh on populate") + keep_completed: bool = Field(default=False, description="Keep success records in jobs table") + stale_timeout: int = Field(default=3600, description="Seconds before pending job is considered stale") + default_priority: int = Field(default=5, description="Default priority for new jobs (lower = more urgent)") + + class ObjectStorageSettings(BaseSettings): """Object storage configuration for the object type.""" @@ -250,6 +266,7 @@ class Config(BaseSettings): connection: ConnectionSettings = Field(default_factory=ConnectionSettings) display: DisplaySettings = Field(default_factory=DisplaySettings) external: ExternalSettings = Field(default_factory=ExternalSettings) + jobs: JobsSettings = Field(default_factory=JobsSettings) object_storage: ObjectStorageSettings = Field(default_factory=ObjectStorageSettings) # Top-level settings diff --git a/src/datajoint/table.py b/src/datajoint/table.py index 356c538ed..02374b9ff 100644 --- a/src/datajoint/table.py +++ b/src/datajoint/table.py @@ -14,7 +14,6 @@ import numpy as np import pandas -from . import blob from .condition import make_condition from .declare import alter, declare from .errors import ( @@ -934,7 +933,9 @@ def __make_placeholder(self, name, value, ignore_extra_fields=False, row=None): return None attr = self.heading[name] if attr.adapter: - value = attr.adapter.put(value) + # Custom attribute type: validate and encode + attr.adapter.validate(value) + value = attr.adapter.encode(value, key=None) if value is None or (attr.numeric and (value == "" or np.isnan(float(value)))): # set default value placeholder, value = "DEFAULT", None @@ -948,8 +949,10 @@ def __make_placeholder(self, name, value, ignore_extra_fields=False, row=None): raise DataJointError("badly formed UUID value {v} for attribute `{n}`".format(v=value, n=name)) value = value.bytes elif attr.is_blob: - value = blob.pack(value) - value = self.external[attr.store].put(value).bytes if attr.is_external else value + # Adapters (like ) handle serialization in encode() + # Without adapter, blob columns store raw bytes (no serialization) + if attr.is_external: + value = self.external[attr.store].put(value).bytes elif attr.is_attachment: attachment_path = Path(value) if attr.is_external: diff --git a/src/datajoint/user_tables.py b/src/datajoint/user_tables.py index d7faeb285..59065e7f1 100644 --- a/src/datajoint/user_tables.py +++ b/src/datajoint/user_tables.py @@ -152,6 +152,15 @@ class Imported(UserTable, AutoPopulate): _prefix = "_" tier_regexp = r"(?P" + _prefix + _base_regexp + ")" + def drop_quick(self): + """ + Drop the table and its associated jobs table. + """ + # Drop the jobs table first if it exists + if self._jobs_table is not None and self._jobs_table.is_declared: + self._jobs_table.drop_quick() + super().drop_quick() + class Computed(UserTable, AutoPopulate): """ @@ -162,6 +171,15 @@ class Computed(UserTable, AutoPopulate): _prefix = "__" tier_regexp = r"(?P" + _prefix + _base_regexp + ")" + def drop_quick(self): + """ + Drop the table and its associated jobs table. + """ + # Drop the jobs table first if it exists + if self._jobs_table is not None and self._jobs_table.is_declared: + self._jobs_table.drop_quick() + super().drop_quick() + class Part(UserTable): """ diff --git a/tests/conftest.py b/tests/conftest.py index c2f2a5ae9..23222f43a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -16,9 +16,7 @@ import datajoint as dj from datajoint.errors import ( - ADAPTED_TYPE_SWITCH, FILEPATH_FEATURE_SWITCH, - DataJointError, ) from . import schema, schema_advanced, schema_external, schema_object, schema_simple @@ -56,21 +54,6 @@ def clean_autopopulate(experiment, trial, ephys): experiment.delete() -@pytest.fixture -def clean_jobs(schema_any): - """ - Explicit cleanup fixture for jobs tests. - - Cleans jobs table before test runs. - Tests must explicitly request this fixture to get cleanup. - """ - try: - schema_any.jobs.delete() - except DataJointError: - pass - yield - - @pytest.fixture def clean_test_tables(test, test_extra, test_no_extra): """ @@ -334,10 +317,14 @@ def monkeymodule(): @pytest.fixture -def enable_adapted_types(monkeypatch): - monkeypatch.setenv(ADAPTED_TYPE_SWITCH, "TRUE") +def enable_adapted_types(): + """ + Deprecated fixture - custom attribute types no longer require a feature flag. + + This fixture is kept for backward compatibility but does nothing. + Custom types are now enabled by default via the AttributeType system. + """ yield - monkeypatch.delenv(ADAPTED_TYPE_SWITCH, raising=True) @pytest.fixture @@ -566,10 +553,6 @@ def mock_cache(tmpdir_factory): def schema_any(connection_test, prefix): schema_any = dj.Schema(prefix + "_test1", schema.LOCALS_ANY, connection=connection_test) assert schema.LOCALS_ANY, "LOCALS_ANY is empty" - try: - schema_any.jobs.delete() - except DataJointError: - pass schema_any(schema.TTest) schema_any(schema.TTest2) schema_any(schema.TTest3) @@ -609,10 +592,6 @@ def schema_any(connection_test, prefix): schema_any(schema.Stimulus) schema_any(schema.Longblob) yield schema_any - try: - schema_any.jobs.delete() - except DataJointError: - pass schema_any.drop() @@ -621,10 +600,6 @@ def schema_any_fresh(connection_test, prefix): """Function-scoped schema_any for tests that need fresh schema state.""" schema_any = dj.Schema(prefix + "_test1_fresh", schema.LOCALS_ANY, connection=connection_test) assert schema.LOCALS_ANY, "LOCALS_ANY is empty" - try: - schema_any.jobs.delete() - except DataJointError: - pass schema_any(schema.TTest) schema_any(schema.TTest2) schema_any(schema.TTest3) @@ -664,10 +639,6 @@ def schema_any_fresh(connection_test, prefix): schema_any(schema.Stimulus) schema_any(schema.Longblob) yield schema_any - try: - schema_any.jobs.delete() - except DataJointError: - pass schema_any.drop() diff --git a/tests/test_adapted_attributes.py b/tests/test_adapted_attributes.py index 1060a50ed..0b4285ffb 100644 --- a/tests/test_adapted_attributes.py +++ b/tests/test_adapted_attributes.py @@ -1,3 +1,10 @@ +""" +Tests for adapted/custom attribute types. + +These tests use the legacy AttributeAdapter API for backward compatibility testing. +""" + +import warnings from itertools import zip_longest import networkx as nx @@ -8,6 +15,9 @@ from . import schema_adapted from .schema_adapted import Connectivity, Layout +# Filter deprecation warnings from legacy AttributeAdapter usage in these tests +pytestmark = pytest.mark.filterwarnings("ignore::DeprecationWarning") + @pytest.fixture def schema_name(prefix): @@ -16,24 +26,28 @@ def schema_name(prefix): @pytest.fixture def adapted_graph_instance(): - yield schema_adapted.GraphAdapter() + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + yield schema_adapted.GraphAdapter() @pytest.fixture def schema_ad( connection_test, adapted_graph_instance, - enable_adapted_types, enable_filepath_feature, s3_creds, tmpdir, schema_name, ): dj.config["stores"] = {"repo-s3": dict(s3_creds, protocol="s3", location="adapted/repo", stage=str(tmpdir))} + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + layout_adapter = schema_adapted.LayoutToFilepath() context = { **schema_adapted.LOCALS_ADAPTED, "graph": adapted_graph_instance, - "layout_to_filepath": schema_adapted.LayoutToFilepath(), + "layout_to_filepath": layout_adapter, } schema = dj.schema(schema_name, context=context, connection=connection_test) schema(schema_adapted.Connectivity) @@ -92,7 +106,7 @@ def test_adapted_filepath_type(schema_ad, minio_client): c.delete() -def test_adapted_spawned(local_schema, enable_adapted_types): +def test_adapted_spawned(local_schema): c = Connectivity() # a spawned class graphs = [ nx.lollipop_graph(4, 2), diff --git a/tests/test_attribute_type.py b/tests/test_attribute_type.py new file mode 100644 index 000000000..f8f822a60 --- /dev/null +++ b/tests/test_attribute_type.py @@ -0,0 +1,419 @@ +""" +Tests for the new AttributeType system. +""" + +import pytest + +import datajoint as dj +from datajoint.attribute_type import ( + AttributeType, + _type_registry, + get_type, + is_type_registered, + list_types, + register_type, + resolve_dtype, + unregister_type, +) +from datajoint.errors import DataJointError + + +class TestAttributeTypeRegistry: + """Tests for the type registry functionality.""" + + def setup_method(self): + """Clear any test types from registry before each test.""" + for name in list(_type_registry.keys()): + if name.startswith("test_"): + del _type_registry[name] + + def teardown_method(self): + """Clean up test types after each test.""" + for name in list(_type_registry.keys()): + if name.startswith("test_"): + del _type_registry[name] + + def test_register_type_decorator(self): + """Test registering a type using the decorator.""" + + @register_type + class TestType(AttributeType): + type_name = "test_decorator" + dtype = "longblob" + + def encode(self, value, *, key=None): + return value + + def decode(self, stored, *, key=None): + return stored + + assert is_type_registered("test_decorator") + assert get_type("test_decorator").type_name == "test_decorator" + + def test_register_type_direct(self): + """Test registering a type by calling register_type directly.""" + + class TestType(AttributeType): + type_name = "test_direct" + dtype = "varchar(255)" + + def encode(self, value, *, key=None): + return str(value) + + def decode(self, stored, *, key=None): + return stored + + register_type(TestType) + assert is_type_registered("test_direct") + + def test_register_type_idempotent(self): + """Test that registering the same type twice is idempotent.""" + + @register_type + class TestType(AttributeType): + type_name = "test_idempotent" + dtype = "int" + + def encode(self, value, *, key=None): + return value + + def decode(self, stored, *, key=None): + return stored + + # Second registration should not raise + register_type(TestType) + assert is_type_registered("test_idempotent") + + def test_register_duplicate_name_different_class(self): + """Test that registering different classes with same name raises error.""" + + @register_type + class TestType1(AttributeType): + type_name = "test_duplicate" + dtype = "int" + + def encode(self, value, *, key=None): + return value + + def decode(self, stored, *, key=None): + return stored + + class TestType2(AttributeType): + type_name = "test_duplicate" + dtype = "varchar(100)" + + def encode(self, value, *, key=None): + return str(value) + + def decode(self, stored, *, key=None): + return stored + + with pytest.raises(DataJointError, match="already registered"): + register_type(TestType2) + + def test_unregister_type(self): + """Test unregistering a type.""" + + @register_type + class TestType(AttributeType): + type_name = "test_unregister" + dtype = "int" + + def encode(self, value, *, key=None): + return value + + def decode(self, stored, *, key=None): + return stored + + assert is_type_registered("test_unregister") + unregister_type("test_unregister") + assert not is_type_registered("test_unregister") + + def test_get_type_not_found(self): + """Test that getting an unregistered type raises error.""" + with pytest.raises(DataJointError, match="Unknown attribute type"): + get_type("nonexistent_type") + + def test_list_types(self): + """Test listing registered types.""" + + @register_type + class TestType(AttributeType): + type_name = "test_list" + dtype = "int" + + def encode(self, value, *, key=None): + return value + + def decode(self, stored, *, key=None): + return stored + + types = list_types() + assert "test_list" in types + assert types == sorted(types) # Should be sorted + + def test_get_type_strips_brackets(self): + """Test that get_type accepts names with or without angle brackets.""" + + @register_type + class TestType(AttributeType): + type_name = "test_brackets" + dtype = "int" + + def encode(self, value, *, key=None): + return value + + def decode(self, stored, *, key=None): + return stored + + assert get_type("test_brackets") is get_type("") + + +class TestAttributeTypeValidation: + """Tests for the validate method.""" + + def setup_method(self): + for name in list(_type_registry.keys()): + if name.startswith("test_"): + del _type_registry[name] + + def teardown_method(self): + for name in list(_type_registry.keys()): + if name.startswith("test_"): + del _type_registry[name] + + def test_validate_called_default(self): + """Test that default validate accepts any value.""" + + @register_type + class TestType(AttributeType): + type_name = "test_validate_default" + dtype = "longblob" + + def encode(self, value, *, key=None): + return value + + def decode(self, stored, *, key=None): + return stored + + t = get_type("test_validate_default") + # Default validate should not raise for any value + t.validate(None) + t.validate(42) + t.validate("string") + t.validate([1, 2, 3]) + + def test_validate_custom(self): + """Test custom validation logic.""" + + @register_type + class PositiveIntType(AttributeType): + type_name = "test_positive_int" + dtype = "int" + + def encode(self, value, *, key=None): + return value + + def decode(self, stored, *, key=None): + return stored + + def validate(self, value): + if not isinstance(value, int): + raise TypeError(f"Expected int, got {type(value).__name__}") + if value < 0: + raise ValueError("Value must be positive") + + t = get_type("test_positive_int") + t.validate(42) # Should pass + + with pytest.raises(TypeError): + t.validate("not an int") + + with pytest.raises(ValueError): + t.validate(-1) + + +class TestTypeChaining: + """Tests for type chaining (dtype referencing another custom type).""" + + def setup_method(self): + for name in list(_type_registry.keys()): + if name.startswith("test_"): + del _type_registry[name] + + def teardown_method(self): + for name in list(_type_registry.keys()): + if name.startswith("test_"): + del _type_registry[name] + + def test_resolve_native_dtype(self): + """Test resolving a native dtype.""" + final_dtype, chain = resolve_dtype("longblob") + assert final_dtype == "longblob" + assert chain == [] + + def test_resolve_custom_dtype(self): + """Test resolving a custom dtype.""" + + @register_type + class TestType(AttributeType): + type_name = "test_resolve" + dtype = "varchar(100)" + + def encode(self, value, *, key=None): + return value + + def decode(self, stored, *, key=None): + return stored + + final_dtype, chain = resolve_dtype("") + assert final_dtype == "varchar(100)" + assert len(chain) == 1 + assert chain[0].type_name == "test_resolve" + + def test_resolve_chained_dtype(self): + """Test resolving a chained dtype.""" + + @register_type + class InnerType(AttributeType): + type_name = "test_inner" + dtype = "longblob" + + def encode(self, value, *, key=None): + return value + + def decode(self, stored, *, key=None): + return stored + + @register_type + class OuterType(AttributeType): + type_name = "test_outer" + dtype = "" + + def encode(self, value, *, key=None): + return value + + def decode(self, stored, *, key=None): + return stored + + final_dtype, chain = resolve_dtype("") + assert final_dtype == "longblob" + assert len(chain) == 2 + assert chain[0].type_name == "test_outer" + assert chain[1].type_name == "test_inner" + + def test_circular_reference_detection(self): + """Test that circular type references are detected.""" + + @register_type + class TypeA(AttributeType): + type_name = "test_circular_a" + dtype = "" + + def encode(self, value, *, key=None): + return value + + def decode(self, stored, *, key=None): + return stored + + @register_type + class TypeB(AttributeType): + type_name = "test_circular_b" + dtype = "" + + def encode(self, value, *, key=None): + return value + + def decode(self, stored, *, key=None): + return stored + + with pytest.raises(DataJointError, match="Circular type reference"): + resolve_dtype("") + + +class TestExportsAndAPI: + """Test that the public API is properly exported.""" + + def test_exports_from_datajoint(self): + """Test that AttributeType and helpers are exported from datajoint.""" + assert hasattr(dj, "AttributeType") + assert hasattr(dj, "register_type") + assert hasattr(dj, "list_types") + + def test_attribute_adapter_deprecated(self): + """Test that AttributeAdapter is still available but deprecated.""" + assert hasattr(dj, "AttributeAdapter") + # AttributeAdapter should be a subclass of AttributeType + assert issubclass(dj.AttributeAdapter, dj.AttributeType) + + +class TestDJBlobType: + """Tests for the built-in DJBlobType.""" + + def test_djblob_is_registered(self): + """Test that djblob is automatically registered.""" + assert is_type_registered("djblob") + + def test_djblob_properties(self): + """Test DJBlobType properties.""" + blob_type = get_type("djblob") + assert blob_type.type_name == "djblob" + assert blob_type.dtype == "longblob" + + def test_djblob_encode_decode_roundtrip(self): + """Test that encode/decode is a proper roundtrip.""" + import numpy as np + + blob_type = get_type("djblob") + + # Test with various data types + test_data = [ + {"key": "value", "number": 42}, + [1, 2, 3, 4, 5], + np.array([1.0, 2.0, 3.0]), + "simple string", + (1, 2, 3), + None, + ] + + for original in test_data: + encoded = blob_type.encode(original) + assert isinstance(encoded, bytes) + decoded = blob_type.decode(encoded) + if isinstance(original, np.ndarray): + np.testing.assert_array_equal(decoded, original) + else: + assert decoded == original + + def test_djblob_encode_produces_valid_blob_format(self): + """Test that encoded data has valid blob protocol header.""" + blob_type = get_type("djblob") + encoded = blob_type.encode({"test": "data"}) + + # Should start with compression prefix or protocol header + valid_prefixes = (b"ZL123\0", b"mYm\0", b"dj0\0") + assert any(encoded.startswith(p) for p in valid_prefixes) + + def test_djblob_in_list_types(self): + """Test that djblob appears in list_types.""" + types = list_types() + assert "djblob" in types + + def test_djblob_handles_serialization(self): + """Test that DJBlobType handles serialization internally. + + With the new design: + - Plain longblob columns store/return raw bytes (no serialization) + - handles pack/unpack in encode/decode + - Legacy AttributeAdapter handles pack/unpack internally for backward compat + """ + blob_type = get_type("djblob") + + # DJBlobType.encode() should produce packed bytes + data = {"key": "value"} + encoded = blob_type.encode(data) + assert isinstance(encoded, bytes) + + # DJBlobType.decode() should unpack back to original + decoded = blob_type.decode(encoded) + assert decoded == data diff --git a/tests/test_autopopulate.py b/tests/test_autopopulate.py index b22b252ee..1f1d33a84 100644 --- a/tests/test_autopopulate.py +++ b/tests/test_autopopulate.py @@ -61,17 +61,22 @@ def test_populate_key_list(clean_autopopulate, subject, experiment, trial): assert n == ret["success_count"] -def test_populate_exclude_error_and_ignore_jobs(clean_autopopulate, schema_any, subject, experiment): +def test_populate_exclude_error_and_ignore_jobs(clean_autopopulate, subject, experiment): # test simple populate assert subject, "root tables are empty" assert not experiment, "table already filled?" + # Ensure jobs table is set up by refreshing + jobs = experiment.jobs + jobs.refresh() + keys = experiment.key_source.fetch("KEY", limit=2) for idx, key in enumerate(keys): if idx == 0: - schema_any.jobs.ignore(experiment.table_name, key) + jobs.ignore(key) else: - schema_any.jobs.error(experiment.table_name, key, "") + jobs.reserve(key) + jobs.error(key, error_message="Test error") experiment.populate(reserve_jobs=True) assert len(experiment.key_source & experiment) == len(experiment.key_source) - 2 diff --git a/tests/test_jobs.py b/tests/test_jobs.py index 4ffc431fe..1925eb4b5 100644 --- a/tests/test_jobs.py +++ b/tests/test_jobs.py @@ -1,130 +1,398 @@ +""" +Tests for the Autopopulate 2.0 per-table jobs system. +""" + import random import string - import datajoint as dj -from datajoint.jobs import ERROR_MESSAGE_LENGTH, TRUNCATION_APPENDIX +from datajoint.jobs import JobsTable, ERROR_MESSAGE_LENGTH, TRUNCATION_APPENDIX from . import schema -def test_reserve_job(clean_jobs, subject, schema_any): - assert subject - table_name = "fake_table" +class TestJobsTableStructure: + """Tests for JobsTable structure and initialization.""" + + def test_jobs_property_exists(self, schema_any): + """Test that Computed tables have a jobs property.""" + assert hasattr(schema.SigIntTable, "jobs") + jobs = schema.SigIntTable().jobs + assert isinstance(jobs, JobsTable) + + def test_jobs_table_name(self, schema_any): + """Test that jobs table has correct naming convention.""" + jobs = schema.SigIntTable().jobs + # SigIntTable is __sig_int_table, jobs should be ~sig_int_table__jobs + assert jobs.table_name.startswith("~") + assert jobs.table_name.endswith("__jobs") + + def test_jobs_table_primary_key(self, schema_any): + """Test that jobs table has FK-derived primary key.""" + jobs = schema.SigIntTable().jobs + # SigIntTable depends on SimpleSource with pk 'id' + assert "id" in jobs.primary_key + + def test_jobs_table_status_column(self, schema_any): + """Test that jobs table has status column with correct enum values.""" + jobs = schema.SigIntTable().jobs + jobs._ensure_declared() + status_attr = jobs.heading.attributes["status"] + assert "pending" in status_attr.type + assert "reserved" in status_attr.type + assert "success" in status_attr.type + assert "error" in status_attr.type + assert "ignore" in status_attr.type + + +class TestJobsRefresh: + """Tests for JobsTable.refresh() method.""" + + def test_refresh_adds_jobs(self, schema_any): + """Test that refresh() adds pending jobs for keys in key_source.""" + table = schema.SigIntTable() + jobs = table.jobs + jobs.delete() # Clear any existing jobs + + result = jobs.refresh() + assert result["added"] > 0 + assert len(jobs.pending) > 0 + + def test_refresh_with_priority(self, schema_any): + """Test that refresh() sets priority on new jobs.""" + table = schema.SigIntTable() + jobs = table.jobs + jobs.delete() + + jobs.refresh(priority=3) + priorities = jobs.pending.fetch("priority") + assert all(p == 3 for p in priorities) + + def test_refresh_with_delay(self, schema_any): + """Test that refresh() sets scheduled_time in the future.""" + table = schema.SigIntTable() + jobs = table.jobs + jobs.delete() + + jobs.refresh(delay=3600) # 1 hour delay + # Jobs should not be available for processing yet + keys = jobs.fetch_pending() + assert len(keys) == 0 # All jobs are scheduled for later + + def test_refresh_removes_stale_jobs(self, schema_any): + """Test that refresh() removes jobs for deleted upstream records.""" + # This test requires manipulating upstream data + pass # Skip for now + + +class TestJobsReserve: + """Tests for JobsTable.reserve() method.""" + + def test_reserve_pending_job(self, schema_any): + """Test that reserve() transitions pending -> reserved.""" + table = schema.SigIntTable() + jobs = table.jobs + jobs.delete() + jobs.refresh() + + # Get first pending job + key = jobs.pending.fetch("KEY", limit=1)[0] + jobs.reserve(key) + + # Verify status changed + status = (jobs & key).fetch1("status") + assert status == "reserved" + + def test_reserve_sets_metadata(self, schema_any): + """Test that reserve() sets user, host, pid, connection_id.""" + table = schema.SigIntTable() + jobs = table.jobs + jobs.delete() + jobs.refresh() + + key = jobs.pending.fetch("KEY", limit=1)[0] + jobs.reserve(key) + + # Verify metadata was set + row = (jobs & key).fetch1() + assert row["status"] == "reserved" + assert row["reserved_time"] is not None + assert row["user"] != "" + assert row["host"] != "" + assert row["pid"] > 0 + assert row["connection_id"] > 0 + + +class TestJobsComplete: + """Tests for JobsTable.complete() method.""" + + def test_complete_with_keep_false(self, schema_any): + """Test that complete() deletes job when keep=False.""" + table = schema.SigIntTable() + jobs = table.jobs + jobs.delete() + jobs.refresh() + + key = jobs.pending.fetch("KEY", limit=1)[0] + jobs.reserve(key) + jobs.complete(key, duration=1.5, keep=False) + + assert key not in jobs + + def test_complete_with_keep_true(self, schema_any): + """Test that complete() marks job as success when keep=True.""" + table = schema.SigIntTable() + jobs = table.jobs + jobs.delete() + jobs.refresh() + + key = jobs.pending.fetch("KEY", limit=1)[0] + jobs.reserve(key) + jobs.complete(key, duration=1.5, keep=True) + + status = (jobs & key).fetch1("status") + assert status == "success" + + +class TestJobsError: + """Tests for JobsTable.error() method.""" + + def test_error_marks_status(self, schema_any): + """Test that error() marks job as error with message.""" + table = schema.SigIntTable() + jobs = table.jobs + jobs.delete() + jobs.refresh() + + key = jobs.pending.fetch("KEY", limit=1)[0] + jobs.reserve(key) + jobs.error(key, error_message="Test error", error_stack="stack trace") + + status, msg = (jobs & key).fetch1("status", "error_message") + assert status == "error" + assert msg == "Test error" + + def test_error_truncates_long_message(self, schema_any): + """Test that error() truncates long error messages.""" + table = schema.SigIntTable() + jobs = table.jobs + jobs.delete() + jobs.refresh() - # reserve jobs - for key in subject.fetch("KEY"): - assert schema_any.jobs.reserve(table_name, key), "failed to reserve a job" + long_message = "".join(random.choice(string.ascii_letters) for _ in range(ERROR_MESSAGE_LENGTH + 100)) + + key = jobs.pending.fetch("KEY", limit=1)[0] + jobs.reserve(key) + jobs.error(key, error_message=long_message) + + msg = (jobs & key).fetch1("error_message") + assert len(msg) == ERROR_MESSAGE_LENGTH + assert msg.endswith(TRUNCATION_APPENDIX) + + +class TestJobsIgnore: + """Tests for JobsTable.ignore() method.""" + + def test_ignore_marks_status(self, schema_any): + """Test that ignore() marks job as ignore.""" + table = schema.SigIntTable() + jobs = table.jobs + jobs.delete() + jobs.refresh() + + key = jobs.pending.fetch("KEY", limit=1)[0] + jobs.ignore(key) + + status = (jobs & key).fetch1("status") + assert status == "ignore" + + def test_ignore_new_key(self, schema_any): + """Test that ignore() can create new job with ignore status.""" + table = schema.SigIntTable() + jobs = table.jobs + jobs.delete() - # refuse jobs - for key in subject.fetch("KEY"): - assert not schema_any.jobs.reserve(table_name, key), "failed to respect reservation" + # Don't refresh - ignore a key directly + key = {"id": 1} + jobs.ignore(key) - # complete jobs - for key in subject.fetch("KEY"): - schema_any.jobs.complete(table_name, key) - assert not schema_any.jobs, "failed to free jobs" + status = (jobs & key).fetch1("status") + assert status == "ignore" - # reserve jobs again - for key in subject.fetch("KEY"): - assert schema_any.jobs.reserve(table_name, key), "failed to reserve new jobs" - # finish with error - for key in subject.fetch("KEY"): - schema_any.jobs.error(table_name, key, "error message") +class TestJobsStatusProperties: + """Tests for status filter properties.""" - # refuse jobs with errors - for key in subject.fetch("KEY"): - assert not schema_any.jobs.reserve(table_name, key), "failed to ignore error jobs" + def test_pending_property(self, schema_any): + """Test that pending property returns pending jobs.""" + table = schema.SigIntTable() + jobs = table.jobs + jobs.delete() + jobs.refresh() + + assert len(jobs.pending) > 0 + statuses = jobs.pending.fetch("status") + assert all(s == "pending" for s in statuses) + + def test_reserved_property(self, schema_any): + """Test that reserved property returns reserved jobs.""" + table = schema.SigIntTable() + jobs = table.jobs + jobs.delete() + jobs.refresh() - # clear error jobs - (schema_any.jobs & dict(status="error")).delete() - assert not schema_any.jobs, "failed to clear error jobs" + key = jobs.pending.fetch("KEY", limit=1)[0] + jobs.reserve(key) + assert len(jobs.reserved) == 1 + statuses = jobs.reserved.fetch("status") + assert all(s == "reserved" for s in statuses) -def test_restrictions(clean_jobs, schema_any): - jobs = schema_any.jobs - jobs.delete() - jobs.reserve("a", {"key": "a1"}) - jobs.reserve("a", {"key": "a2"}) - jobs.reserve("b", {"key": "b1"}) - jobs.error("a", {"key": "a2"}, "error") - jobs.error("b", {"key": "b1"}, "error") + def test_errors_property(self, schema_any): + """Test that errors property returns error jobs.""" + table = schema.SigIntTable() + jobs = table.jobs + jobs.delete() + jobs.refresh() - assert len(jobs & {"table_name": "a"}) == 2 - assert len(jobs & {"status": "error"}) == 2 - assert len(jobs & {"table_name": "a", "status": "error"}) == 1 - jobs.delete() + key = jobs.pending.fetch("KEY", limit=1)[0] + jobs.reserve(key) + jobs.error(key, error_message="test") + + assert len(jobs.errors) == 1 + def test_ignored_property(self, schema_any): + """Test that ignored property returns ignored jobs.""" + table = schema.SigIntTable() + jobs = table.jobs + jobs.delete() + jobs.refresh() -def test_sigint(clean_jobs, schema_any): - try: - schema.SigIntTable().populate(reserve_jobs=True) - except KeyboardInterrupt: - pass + key = jobs.pending.fetch("KEY", limit=1)[0] + jobs.ignore(key) + + assert len(jobs.ignored) == 1 + + +class TestJobsProgress: + """Tests for JobsTable.progress() method.""" + + def test_progress_returns_counts(self, schema_any): + """Test that progress() returns status counts.""" + table = schema.SigIntTable() + jobs = table.jobs + jobs.delete() + jobs.refresh() + + progress = jobs.progress() - assert len(schema_any.jobs.fetch()), "SigInt jobs table is empty" - status, error_message = schema_any.jobs.fetch1("status", "error_message") - assert status == "error" - assert error_message == "KeyboardInterrupt" + assert "pending" in progress + assert "reserved" in progress + assert "success" in progress + assert "error" in progress + assert "ignore" in progress + assert "total" in progress + assert progress["total"] == sum(progress[k] for k in ["pending", "reserved", "success", "error", "ignore"]) + + +class TestPopulateWithJobs: + """Tests for populate() with reserve_jobs=True using new system.""" + def test_populate_creates_jobs_table(self, schema_any): + """Test that populate with reserve_jobs creates jobs table.""" + table = schema.SigIntTable() + # Clear target table to allow re-population + table.delete() + + # First populate should create jobs table + table.populate(reserve_jobs=True, suppress_errors=True, max_calls=1) + + assert table.jobs.is_declared -def test_sigterm(clean_jobs, schema_any): - try: - schema.SigTermTable().populate(reserve_jobs=True) - except SystemExit: + def test_populate_uses_jobs_queue(self, schema_any): + """Test that populate processes jobs from queue.""" + table = schema.Experiment() + table.delete() + jobs = table.jobs + jobs.delete() + + # Refresh to add jobs + jobs.refresh() + initial_pending = len(jobs.pending) + assert initial_pending > 0 + + # Populate one job + result = table.populate(reserve_jobs=True, max_calls=1) + assert result["success_count"] >= 0 # May be 0 if error + + def test_populate_with_priority_filter(self, schema_any): + """Test that populate respects priority filter.""" + table = schema.Experiment() + table.delete() + jobs = table.jobs + jobs.delete() + + # Add jobs with different priorities + # This would require the table to have multiple keys + pass # Skip for now + + +class TestSchemaJobs: + """Tests for schema.jobs property.""" + + def test_schema_jobs_returns_list(self, schema_any): + """Test that schema.jobs returns list of JobsTable objects.""" + jobs_list = schema_any.jobs + assert isinstance(jobs_list, list) + + def test_schema_jobs_contains_jobs_tables(self, schema_any): + """Test that schema.jobs contains JobsTable instances.""" + jobs_list = schema_any.jobs + for jobs in jobs_list: + assert isinstance(jobs, JobsTable) + + +class TestTableDropLifecycle: + """Tests for table drop lifecycle.""" + + def test_drop_removes_jobs_table(self, schema_any): + """Test that dropping a table also drops its jobs table.""" + # Create a temporary computed table for this test + # This test would modify the schema, so skip for now pass - assert len(schema_any.jobs.fetch()), "SigTerm jobs table is empty" - status, error_message = schema_any.jobs.fetch1("status", "error_message") - assert status == "error" - assert error_message == "SystemExit: SIGTERM received" - - -def test_suppress_dj_errors(clean_jobs, schema_any): - """test_suppress_dj_errors: dj errors suppressible w/o native py blobs""" - with dj.config.override(enable_python_native_blobs=False): - schema.ErrorClass.populate(reserve_jobs=True, suppress_errors=True) - assert len(schema.DjExceptionName()) == len(schema_any.jobs) > 0 - - -def test_long_error_message(clean_jobs, subject, schema_any): - # create long error message - long_error_message = "".join(random.choice(string.ascii_letters) for _ in range(ERROR_MESSAGE_LENGTH + 100)) - short_error_message = "".join(random.choice(string.ascii_letters) for _ in range(ERROR_MESSAGE_LENGTH // 2)) - assert subject - table_name = "fake_table" - - key = subject.fetch("KEY", limit=1)[0] - - # test long error message - schema_any.jobs.reserve(table_name, key) - schema_any.jobs.error(table_name, key, long_error_message) - error_message = schema_any.jobs.fetch1("error_message") - assert len(error_message) == ERROR_MESSAGE_LENGTH, "error message is longer than max allowed" - assert error_message.endswith(TRUNCATION_APPENDIX), "appropriate ending missing for truncated error message" - schema_any.jobs.delete() - - # test long error message - schema_any.jobs.reserve(table_name, key) - schema_any.jobs.error(table_name, key, short_error_message) - error_message = schema_any.jobs.fetch1("error_message") - assert error_message == short_error_message, "error messages do not agree" - assert not error_message.endswith(TRUNCATION_APPENDIX), "error message should not be truncated" - schema_any.jobs.delete() - - -def test_long_error_stack(clean_jobs, subject, schema_any): - # create long error stack - STACK_SIZE = 89942 # Does not fit into small blob (should be 64k, but found to be higher) - long_error_stack = "".join(random.choice(string.ascii_letters) for _ in range(STACK_SIZE)) - assert subject - table_name = "fake_table" - - key = subject.fetch("KEY", limit=1)[0] - - # test long error stack - schema_any.jobs.reserve(table_name, key) - schema_any.jobs.error(table_name, key, "error message", long_error_stack) - error_stack = schema_any.jobs.fetch1("error_stack") - assert error_stack == long_error_stack, "error stacks do not agree" + +class TestConfiguration: + """Tests for jobs configuration settings.""" + + def test_default_priority_config(self, schema_any): + """Test that config.jobs.default_priority is used.""" + original = dj.config.jobs.default_priority + try: + dj.config.jobs.default_priority = 3 + + table = schema.SigIntTable() + jobs = table.jobs + jobs.delete() + jobs.refresh() # Should use default priority from config + + priorities = jobs.pending.fetch("priority") + assert all(p == 3 for p in priorities) + finally: + dj.config.jobs.default_priority = original + + def test_keep_completed_config(self, schema_any): + """Test that config.jobs.keep_completed affects complete().""" + # Test with keep_completed=True + with dj.config.override(jobs__keep_completed=True): + table = schema.SigIntTable() + jobs = table.jobs + jobs.delete() + jobs.refresh() + + key = jobs.pending.fetch("KEY", limit=1)[0] + jobs.reserve(key) + jobs.complete(key) # Should use config + + status = (jobs & key).fetch1("status") + assert status == "success"