Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
218 changes: 91 additions & 127 deletions src/datajoint/staged_insert.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,19 @@
to object storage before finalizing the database insert.
"""

import json
import mimetypes
from contextlib import contextmanager
from datetime import datetime, timezone
from typing import IO, Any
from typing import IO, TYPE_CHECKING, Any

import fsspec

from .codecs import resolve_dtype
from .errors import DataJointError
from .storage import StorageBackend, build_object_path
from .hash_registry import get_store_backend
from .storage import build_object_path

if TYPE_CHECKING:
from .storage import StorageBackend


class StagedInsert:
Expand All @@ -30,15 +33,14 @@ class StagedInsert:
staged.rec['subject_id'] = 123
staged.rec['session_id'] = 45

# Create object storage directly
# Write directly to object storage
z = zarr.open(staged.store('raw_data', '.zarr'), mode='w', shape=(1000, 1000))
z[:] = data

# Assign to record
staged.rec['raw_data'] = z

# On successful exit: metadata computed, record inserted
# On exception: storage cleaned up, no record inserted
# On clean exit: metadata is computed and the row is inserted.
# The caller does NOT assign anything to staged.rec[<object field>] —
# the framework computes the metadata dict.
# On exception: storage cleaned up, no row inserted.
"""

def __init__(self, table):
Expand All @@ -50,8 +52,7 @@ def __init__(self, table):
"""
self._table = table
self._rec: dict[str, Any] = {}
self._staged_objects: dict[str, dict] = {} # field -> {path, ext, token}
self._backend: StorageBackend | None = None
self._staged_objects: dict[str, dict] = {} # field -> {relative_path, ext, token, store_name}

@property
def rec(self) -> dict[str, Any]:
Expand All @@ -60,60 +61,57 @@ def rec(self) -> dict[str, Any]:

@property
def fs(self) -> fsspec.AbstractFileSystem:
    """
    Expose the default store's fsspec filesystem for advanced operations.

    For work on a specific field, prefer ``staged.store(field)`` or
    ``staged.open(field)``: those route to the store resolved from the
    field's type spec, whereas this property always uses the default store.
    """
    default_backend = self._default_backend()
    return default_backend.fs

Args:
field: Name of the object attribute
ext: Optional extension (e.g., ".zarr")
def _default_backend(self) -> "StorageBackend":
    """
    Return the StorageBackend for the default store, or raise a clear error.

    The default store is requested by passing ``None`` as the store name to
    ``get_store_backend`` together with the table connection's config.

    Raises:
        DataJointError: if ``get_store_backend`` raises ``DataJointError``.
            The underlying error is chained as ``__cause__`` so the real
            reason stays visible in the traceback.
    """
    try:
        return get_store_backend(None, config=self._table.connection._config)
    except DataJointError as err:
        raise DataJointError(
            "Storage is not configured. Set stores.default and stores.<name> settings in datajoint.json."
        ) from err

Returns:
Full storage path
def _resolve_field(self, field: str, ext: str) -> tuple[str, "StorageBackend"]:
"""
self._ensure_backend()
Resolve a field to its (relative_path, backend), caching on first call.

Validates the field is an ``<object@>`` attribute and that the full
primary key is set on ``staged.rec``.
"""
if field in self._staged_objects:
return self._staged_objects[field]["full_path"]
info = self._staged_objects[field]
return info["relative_path"], self._field_backend(info["store_name"])

# Validate field is an object attribute
if field not in self._table.heading:
raise DataJointError(f"Attribute '{field}' not found in table heading")

attr = self._table.heading[field]
# Check if this is an object Codec (has codec with "object" as name)
if not (attr.codec and attr.codec.name == "object"):
raise DataJointError(f"Attribute '{field}' is not an <object> type")

# Extract primary key from rec
primary_key = {k: self._rec[k] for k in self._table.primary_key if k in self._rec}
if len(primary_key) != len(self._table.primary_key):
raise DataJointError(
"Primary key values must be set in staged.rec before calling store() or open(). "
f"Missing: {set(self._table.primary_key) - set(primary_key)}"
)

# Get storage spec (uses stores.default)
spec = self._table.connection._config.get_store_spec()
# Resolve the store name from the field's type spec (e.g., <object@local> -> "local")
_, _, store_name = resolve_dtype(f"<{attr.codec.name}>", store_name=attr.store)

config = self._table.connection._config
try:
spec = config.get_store_spec(store_name)
except DataJointError:
raise DataJointError("Storage is not configured. Set stores.default and stores.<name> settings in datajoint.json.")
partition_pattern = spec.get("partition_pattern")
token_length = spec.get("token_length", 8)

# Build storage path (relative - StorageBackend will add location prefix)
relative_path, token = build_object_path(
schema=self._table.database,
table=self._table.class_name,
Expand All @@ -124,18 +122,25 @@ def _get_storage_path(self, field: str, ext: str = "") -> str:
token_length=token_length,
)

# Store staged object info (all paths are relative, backend adds location)
self._staged_objects[field] = {
"relative_path": relative_path,
"ext": ext if ext else None,
"token": token,
"store_name": store_name,
}

return relative_path
return relative_path, self._field_backend(store_name)

def _field_backend(self, store_name: str | None) -> "StorageBackend":
    """
    Return the StorageBackend for the named store.

    Args:
        store_name: Store name to look up; passed through unchanged to
            ``get_store_backend`` (``None`` is accepted).

    Raises:
        DataJointError: if ``get_store_backend`` raises ``DataJointError``.
            The underlying error is chained as ``__cause__`` so the real
            reason stays visible in the traceback.
    """
    try:
        return get_store_backend(store_name, config=self._table.connection._config)
    except DataJointError as err:
        raise DataJointError(
            "Storage is not configured. Set stores.default and stores.<name> settings in datajoint.json."
        ) from err

def store(self, field: str, ext: str = "") -> fsspec.FSMap:
"""
Get an FSMap store for direct writes to an object field.
Get an FSMap for direct writes to an ``<object@>`` field.

Args:
field: Name of the object attribute
Expand All @@ -144,12 +149,12 @@ def store(self, field: str, ext: str = "") -> fsspec.FSMap:
Returns:
fsspec.FSMap suitable for Zarr/xarray
"""
path = self._get_storage_path(field, ext)
return self._backend.get_fsmap(path)
relative_path, backend = self._resolve_field(field, ext)
return backend.get_fsmap(relative_path)

def open(self, field: str, ext: str = "", mode: str = "wb") -> IO:
"""
Open a file for direct writes to an object field.
Open a file for direct writes to an ``<object@>`` field.

Args:
field: Name of the object attribute
Expand All @@ -159,127 +164,86 @@ def open(self, field: str, ext: str = "", mode: str = "wb") -> IO:
Returns:
File-like object for writing
"""
path = self._get_storage_path(field, ext)
return self._backend.open(path, mode)
relative_path, backend = self._resolve_field(field, ext)
return backend.open(relative_path, mode)

def _compute_metadata(self, field: str) -> dict:
"""
Compute metadata for a staged object after writing is complete.
Compute the canonical ``<object@>`` metadata dict for a staged write.

Args:
field: Name of the object attribute
The returned dict is structurally equal to what ``ObjectCodec.encode``
would produce for the same content, modulo ``timestamp``.

Returns:
JSON-serializable metadata dict
Returns
-------
dict
``{path, store, size, ext, is_dir, item_count, timestamp}``
"""
info = self._staged_objects[field]
relative_path = info["relative_path"]
ext = info["ext"]
store_name = info["store_name"]
backend = self._field_backend(store_name)

# Check if it's a directory (multiple files) or single file
# _full_path adds the location prefix
full_remote_path = self._backend._full_path(relative_path)
full_remote_path = backend._full_path(relative_path)

try:
is_dir = self._backend.fs.isdir(full_remote_path)
is_dir = backend.fs.isdir(full_remote_path)
except Exception:
is_dir = False

if is_dir:
# Calculate total size and file count
total_size = 0
item_count = 0
files = []

for root, dirs, filenames in self._backend.fs.walk(full_remote_path):
for root, _dirs, filenames in backend.fs.walk(full_remote_path):
for filename in filenames:
file_path = f"{root}/{filename}"
try:
file_size = self._backend.fs.size(file_path)
rel_path = file_path[len(full_remote_path) :].lstrip("/")
files.append({"path": rel_path, "size": file_size})
total_size += file_size
total_size += backend.fs.size(f"{root}/{filename}")
item_count += 1
except Exception:
pass

# Create manifest
manifest = {
"files": files,
"total_size": total_size,
"item_count": item_count,
"created": datetime.now(timezone.utc).isoformat(),
}

# Write manifest alongside folder
manifest_path = f"{relative_path}.manifest.json"
self._backend.put_buffer(json.dumps(manifest, indent=2).encode(), manifest_path)

metadata = {
"path": relative_path,
"size": total_size,
"hash": None,
"ext": ext,
"is_dir": True,
"timestamp": datetime.now(timezone.utc).isoformat(),
"item_count": item_count,
}
size = total_size
else:
# Single file
try:
size = self._backend.size(relative_path)
size = backend.size(relative_path)
except Exception:
size = 0

metadata = {
"path": relative_path,
"size": size,
"hash": None,
"ext": ext,
"is_dir": False,
"timestamp": datetime.now(timezone.utc).isoformat(),
}

# Add mime_type for files
if ext:
mime_type, _ = mimetypes.guess_type(f"file{ext}")
if mime_type:
metadata["mime_type"] = mime_type

return metadata
item_count = None

return {
"path": relative_path,
"store": store_name,
"size": size,
"ext": ext,
"is_dir": is_dir,
"item_count": item_count,
"timestamp": datetime.now(timezone.utc).isoformat(),
}

def _finalize(self):
"""
Finalize the staged insert by computing metadata and inserting the record.
Compute metadata for each staged object and insert the row.
"""
# Process each staged object
for field in list(self._staged_objects.keys()):
metadata = self._compute_metadata(field)
# Store metadata dict in the record (ObjectType.encode handles it)
self._rec[field] = metadata

# Insert the record
self._rec[field] = self._compute_metadata(field)
self._table.insert1(self._rec)

def _cleanup(self):
"""
Clean up staged objects on failure.
Best-effort removal of staged objects on failure.
"""
if self._backend is None:
return

for field, info in self._staged_objects.items():
relative_path = info["relative_path"]
try:
# Check if it's a directory
full_remote_path = self._backend._full_path(relative_path)
if self._backend.fs.exists(full_remote_path):
if self._backend.fs.isdir(full_remote_path):
self._backend.remove_folder(relative_path)
backend = self._field_backend(info["store_name"])
full_remote_path = backend._full_path(relative_path)
if backend.fs.exists(full_remote_path):
if backend.fs.isdir(full_remote_path):
backend.remove_folder(relative_path)
else:
self._backend.remove(relative_path)
backend.remove(relative_path)
except Exception:
pass # Best effort cleanup
pass # Best-effort cleanup


@contextmanager
Expand All @@ -299,7 +263,7 @@ def staged_insert1(table):
staged.rec['session_id'] = 45
z = zarr.open(staged.store('raw_data', '.zarr'), mode='w')
z[:] = data
staged.rec['raw_data'] = z
# Metadata for 'raw_data' is computed on clean exit; do not assign it here.
"""
staged = StagedInsert(table)
try:
Expand Down
Loading
Loading