"""
Common DB operations for SDP Configuration Database entities.
"""
import enum
import functools
import json
import logging
import re
from typing import Generic, Iterator, TypeVar
from ..base_transaction import BaseTransaction
from ..entity.base import (
EntityBaseModel,
EntityKeyBaseModel,
MultiEntityBaseModel,
)
from ..entity.owner import Owner
# Logger for this module.
logger: logging.Logger = logging.getLogger(__name__)

# Sub-path, relative to an entity's own path, holding its ownership entry.
OWNER_PATH: str = "owner"
# Sub-path, relative to an entity's own path, holding its state.
STATE_PATH: str = "state"
class PathOperations:
    """Basic read operations bound to a single path of the configuration DB."""

    def __init__(self, txn: BaseTransaction, path: str):
        # Every operation goes through this transaction, against this path.
        self._txn = txn
        self._path = path

    @property
    def path(self) -> str:
        """The full path in the SDP Configuration Database."""
        return self._path

    def get(self) -> dict | None:
        """Read whatever the transaction returns for this path."""
        txn = self._txn
        return txn.get(self._path)

    def exists(self) -> bool:
        """Ask the transaction whether anything is stored at this path."""
        txn = self._txn
        return txn.exists(self._path)
class StateOperations(PathOperations):
    """Operations on the state entry of an entity."""

    def create(self, state: dict) -> None:
        """Create the entity's state entry with the contents of ``state``."""
        target = self.path
        self._txn.create(target, state)

    def update(self, state: dict) -> None:
        """Update the entity's state entry with the contents of ``state``."""
        target = self.path
        self._txn.update(target, state)
class OwnershipOperations(PathOperations):
    """Operations on the ownership entry of an entity."""

    def take(self) -> None:
        """
        Take ownership of the entity.

        Creates the entry from the transaction's owner, passing the
        transaction's lease along to ``create``.
        """
        txn = self._txn
        txn.create(self.path, txn.owner.model_dump(), txn.lease)

    def is_owned_by_this_process(self) -> bool:
        """Whether the stored owner equals this transaction's owner."""
        stored_owner = self.get()
        return stored_owner == self._txn.owner

    def is_owned(self) -> bool:
        """Whether any ownership entry is present."""
        return self.exists()

    def get(self) -> Owner | None:
        """Return the stored owner as an :class:`Owner`, or ``None``."""
        raw = super().get()
        return None if raw is None else Owner(**raw)
# Type aliases for the supported entity model classes. Entities can currently
# be modelled either as plain dicts or as pydantic models.
# Model: the value of an entity (plain dict or EntityBaseModel).
Model = dict | EntityBaseModel
# ModelKey: the key of an entity (plain string or EntityKeyBaseModel).
ModelKey = str | EntityKeyBaseModel
# Type variables bound to the aliases above; they parametrise the generic
# operation classes below.
ModelT = TypeVar("ModelT", bound=Model)
ModelKeyT = TypeVar("ModelKeyT", bound=ModelKey)
class EntityOperations(PathOperations, Generic[ModelT]):
    """
    Base class defining common operations that can be performed on an entity.

    Values are converted between :attr:`MODEL_CLASS` instances and the plain
    dictionaries handed to the transaction on every read and write.
    """

    MODEL_CLASS: type[Model] = dict
    """The model class for this entity."""

    def __init__(
        self,
        txn: BaseTransaction,
        path: str,
        key: str | None = None,
        key_parts: dict | None = None,
    ):
        super().__init__(txn, path)
        self._key = key
        # Key part names mapped to their values; empty when not given.
        self._key_parts = key_parts or {}

    @classmethod
    def _to_dict(cls, value: ModelT) -> dict:
        """
        Convert ``value`` into the dictionary written to the database.

        :raises ValueError: if ``value`` is not a :attr:`MODEL_CLASS` instance.
        """
        if not isinstance(value, cls.MODEL_CLASS):
            raise ValueError(
                "Value given for writing is not an instance "
                f"of {cls.MODEL_CLASS.__name__}"
            )
        # Same dict detection as _from_dict, so both directions agree on
        # which classes are stored verbatim.
        if issubclass(cls.MODEL_CLASS, dict):
            return value
        return value.model_dump(mode="json")

    def _from_dict(self, value: dict | None) -> ModelT | None:
        """
        Convert a dictionary read from the database into a model value.

        For MultiEntityBaseModel classes the entity key, built from this
        object's key parts, is added under ``"key"``. A ``"key"`` already
        present in the stored document takes priority.
        """
        if value is None:
            return None
        if issubclass(self.MODEL_CLASS, dict):
            return value
        model_value = {}
        if issubclass(self.MODEL_CLASS, MultiEntityBaseModel):
            # The full model contains not only the data in the DB document,
            # but in cases also the keys used to retrieve it. But sometimes
            # those keys are also present in the document itself, which is
            # given priority
            key_cls = self.MODEL_CLASS.Key
            if issubclass(key_cls, str):
                assert len(self._key_parts) == 1
                # The key is the *value* of the only key part; iterating the
                # dict itself would yield the part's name instead.
                key = next(iter(self._key_parts.values()))
            else:
                key = json.loads(key_cls(**self._key_parts).model_dump_json())
            model_value["key"] = key
        model_value.update(value)
        # pylint: disable-next=no-member
        return self.MODEL_CLASS.model_validate_json(  # type: ignore
            json.dumps(model_value)
        )

    @property
    def key(self) -> str | None:
        """
        For entity types with multiple entries, the key to this individual
        entity. `None` for entity types with single multiplicity.
        """
        return self._key

    @property
    def key_parts(self) -> dict:
        """
        For entity types with multiple entries, the individual parts that make
        up the entity's key. Empty for entity types with single multiplicity.
        """
        return self._key_parts

    def create(self, value: ModelT) -> None:
        """Creates the entity."""
        self._txn.create(self.path, self._to_dict(value))

    def update(self, value: ModelT) -> None:
        """Updates the entity."""
        self._txn.update(self.path, self._to_dict(value))

    def create_or_update(self, value: ModelT) -> None:
        """Creates or updates the entity."""
        self._txn.create_or_update(self.path, self._to_dict(value))

    def delete(self, recurse=False) -> None:
        """Deletes the entity, forwarding ``recurse`` to the transaction."""
        self._txn.delete(self.path, recurse=recurse)

    def get(self) -> ModelT | None:
        """Reads the entity, or ``None`` if nothing is stored."""
        return self._from_dict(super().get())

    def __str__(self) -> str:
        return f'<EntityOperations path="{self.path}">'
# pylint: disable=no-member,too-few-public-methods
class StatefulEntityOperationsMixIn:
    """
    Mix-in granting access to an entity's state operations.

    The host class must provide ``_txn`` and ``path``.
    """

    @property
    def state(self) -> StateOperations:
        """State operations for the single entity this object refers to."""
        state_path = f"{self.path}/{STATE_PATH}"
        return StateOperations(self._txn, state_path)
# pylint: disable=no-member,too-few-public-methods
class OwnedEntityOperationsMixIn:
    """
    Mix-in granting access to an entity's ownership operations.

    The host class must provide ``_txn`` and ``path``.
    """

    @property
    def ownership(self) -> OwnershipOperations:
        """Ownership operations for the single entity this object refers to."""
        owner_path = f"{self.path}/{OWNER_PATH}"
        return OwnershipOperations(self._txn, owner_path)

    def is_alive(self) -> bool:
        """Whether the entity is alive, i.e. its ownership entry exists."""
        return self.ownership.is_owned()

    def take_ownership_if_not_alive(self) -> None:
        """Take ownership of the entity unless it is already alive."""
        ownership = self.ownership
        if ownership.is_owned():
            return
        logger.info("Owner entry not present, taking ownership")
        ownership.take()
class InvalidKey(RuntimeError):
    """
    Raised when a key, given directly or indirectly, fails validation.

    :param key: the offending key.
    :param pattern: the compiled pattern the key was checked against.
    """

    def __init__(self, key: str, pattern: re.Pattern):
        message = f"Invalid {key=} for {pattern=}"
        super().__init__(message)
        self.key = key
        self.pattern = pattern
class ComparisonMode(enum.Enum):
    """How a needle string is matched against a haystack string."""

    SUFFIX = enum.auto()
    """The haystack ends with the needle."""

    PREFIX = enum.auto()
    """The haystack starts with the needle."""

    EQ = enum.auto()
    """The haystack is equal to the needle."""
def _compare(haystack: str, needle: str, mode: ComparisonMode) -> bool:
match mode:
case ComparisonMode.SUFFIX:
return haystack.endswith(needle)
case ComparisonMode.PREFIX:
return haystack.startswith(needle)
case ComparisonMode.EQ:
return haystack == needle
def _make_decorator(condition, error_msg):
def decorator(func):
@functools.wraps(func)
def wrapper(self, *args, **kwargs):
if condition(self):
raise NotImplementedError(error_msg)
return func(self, *args, **kwargs)
return wrapper
return decorator
# Method guards built with _make_decorator: each one makes the decorated
# method raise NotImplementedError when the instance lacks the capability.
pydantic_only = _make_decorator(
    condition=lambda ops: ops.MODEL_CLASS == dict,
    error_msg="Only available for entities with a pydantic model",
)
stateful_only = _make_decorator(
    condition=lambda ops: not ops.HAS_STATE,
    error_msg="Only available for entities that have state",
)
owned_only = _make_decorator(
    condition=lambda ops: not ops.HAS_OWNER,
    error_msg="Only available for entities that have an owner",
)
class CollectiveEntityOperations(Generic[ModelT, ModelKeyT]):
    """
    Base class defining common operations that can be performed on an entity
    type that can have multiple of its entities stored in the SDP Config
    database, all of which are stored under a common prefix. Subclasses need to
    provide the entities' prefix and key patterns. The latter are used to
    verify user invocations and key contents, and build the full key (and thus
    the SDP config DB path) for a particular entity.

    Full keys are the individual key parts joined with ``":"``, in the order
    given by :attr:`KEY_PARTS`. Key parts and full keys must match their
    patterns in full.
    """

    PREFIX: str
    """
    The prefix under which entities are stored in the SDP Configuration DB.
    """

    KEY_PARTS: dict[str, str]
    """
    The ``(name, pattern)`` string pairs describing how keys are built for
    elements belonging to this entity. The order in which elements are given
    affects how full keys are generated.
    """

    MODEL_CLASS: type[Model] = dict
    """
    The model class for this entity. If given, it must derive from
    MultiEntityBaseModel, and implement the key_patterns() class method.
    """

    HAS_STATE: bool = False
    """Whether entities of this type have state."""

    HAS_OWNER: bool = False
    """Whether entities of this type have an owner."""

    # Created dynamically from KEY_PARTS
    _KEY_PART_PATTERNS: dict[str, re.Pattern]
    _KEY_PATTERN: re.Pattern

    # Created dynamically from our own base classes
    _ENTITY_OPERATIONS_CLASS: type[EntityOperations]

    def __new__(cls, *_args, **_kwargs):
        # Per-class setup runs lazily on first instantiation. The marker is
        # looked up in the class' *own* namespace: hasattr() would also find
        # it on an already-initialised parent class, and a subclass would then
        # silently reuse its parent's patterns and operations class.
        if "_KEY_PART_PATTERNS" not in vars(cls):
            cls._setup_class()
        return super().__new__(cls)

    @classmethod
    def _setup_class(cls) -> None:
        """Derive the key patterns and entity operations class for ``cls``."""
        if cls.MODEL_CLASS != dict:
            assert issubclass(
                cls.MODEL_CLASS, MultiEntityBaseModel
            ), f"{cls.__name__}.MODEL_CLASS must be a MultiEntityBaseModel"
            cls.KEY_PARTS = dict(cls.MODEL_CLASS.key_patterns())

        # Each part pattern is grouped so that a top-level alternation such
        # as "a|b" is anchored as a whole, not just at its two ends.
        part_patterns = {
            name: re.compile(f"^(?:{pattern})$")
            for name, pattern in cls.KEY_PARTS.items()
        }
        key_pattern = re.compile(
            "^"
            + ":".join(
                f"(?P<{name}>{pattern})"
                for name, pattern in cls.KEY_PARTS.items()
            )
            + "$"
        )

        bases = (EntityOperations,)
        if cls.HAS_OWNER:
            bases += (OwnedEntityOperationsMixIn,)
        if cls.HAS_STATE:
            bases += (StatefulEntityOperationsMixIn,)
        operations_class = type(
            "_EntityOperations",
            bases,
            {
                "__str__": EntityOperations.__str__,
                "__repr__": EntityOperations.__str__,
                "MODEL_CLASS": cls.MODEL_CLASS,
            },
        )

        # The marker attribute is assigned last, so a failure above (e.g. an
        # invalid regex) doesn't leave the class half-initialised.
        cls._KEY_PATTERN = key_pattern
        cls._ENTITY_OPERATIONS_CLASS = operations_class
        cls._KEY_PART_PATTERNS = part_patterns

    @classmethod
    def _validate_key_part(cls, name: str, value: str) -> None:
        """
        Check that ``value`` is acceptable for the key part called ``name``.

        :raises ValueError: if ``name`` is not a key part of this entity type,
            or ``value`` doesn't fully match that part's pattern.
        """
        pattern = cls._KEY_PART_PATTERNS.get(name)
        if pattern is None:
            raise ValueError(
                f'"{name}" is not a valid key part under {cls.PREFIX}'
            )
        # fullmatch() rather than search(): "$" also matches right before a
        # trailing newline, which would let e.g. "abc\n" through.
        if not pattern.fullmatch(value):
            raise ValueError(
                f'"{value}" is not a valid value for key part "{name}" '
                f"under {cls.PREFIX}"
            )

    @classmethod
    def _create_key(cls, **key_parts) -> str:
        """
        Validate the given key parts and join them into a full key.

        :raises ValueError: if a part name or value is invalid.
        :raises KeyError: if a part listed in KEY_PARTS is not given.
        """
        for name, value in key_parts.items():
            cls._validate_key_part(name, value)
        return ":".join(key_parts[name] for name in cls.KEY_PARTS)

    @classmethod
    def _key_matches(cls, result: re.Match, **key_matches_kwargs) -> bool:
        """
        Whether the parsed key ``result`` satisfies all the given constraints.

        Argument names are key part names, optionally suffixed with
        ``_prefix`` or ``_suffix``. Empty values impose no constraint.

        :raises ValueError: if an argument doesn't name a key part.
        """
        for name, value in key_matches_kwargs.items():
            if name.endswith("_prefix"):
                name = name[: -len("_prefix")]
                mode = ComparisonMode.PREFIX
            elif name.endswith("_suffix"):
                name = name[: -len("_suffix")]
                mode = ComparisonMode.SUFFIX
            else:
                mode = ComparisonMode.EQ
            if name not in cls.KEY_PARTS:
                raise ValueError(f"{name} is not valid under {cls.PREFIX}")
            if not value:
                continue
            if not _compare(result[name], value, mode):
                return False
        return True

    @classmethod
    def _is_valid_key(cls, key: str) -> bool:
        """Whether ``key`` fully matches the full-key pattern."""
        return cls._KEY_PATTERN.fullmatch(key) is not None

    @classmethod
    def _key_model(cls):
        """The key model class when keys are EntityKeyBaseModels, else None."""
        if cls.MODEL_CLASS != dict and issubclass(
            cls.MODEL_CLASS.Key, EntityKeyBaseModel
        ):
            return cls.MODEL_CLASS.Key
        return None

    @classmethod
    def _get_key_and_path(cls, *key_args, **key_kwargs):
        """
        Resolve user-given key arguments into ``(path, key, key_parts)``.

        See :meth:`index_by_key_parts` for the accepted argument forms.
        """
        if key_args and key_kwargs:
            raise ValueError(
                "Can't use positional and keyword arguments simultaneously"
            )
        if key_args:
            if len(key_args) != 1:
                raise ValueError(
                    "Only single positional argument can be given"
                )
            if len(cls._KEY_PART_PATTERNS) != 1:
                raise ValueError(
                    "Positional arguments unsupported for multi-part keys"
                )
            key_part_name = next(iter(cls.KEY_PARTS))
            cls._validate_key_part(key_part_name, key_args[0])
            key = key_args[0]
            key_parts = {key_part_name: key}
            assert cls._is_valid_key(key)
        elif "key" in key_kwargs:
            if len(key_kwargs) > 1:
                raise ValueError(
                    "'key' cannot be combined with other keyword arguments"
                )
            key = key_kwargs["key"]
            key_match = cls._KEY_PATTERN.fullmatch(key)
            if key_match is None:
                raise InvalidKey(key, cls._KEY_PATTERN)
            key_parts = key_match.groupdict()
        else:
            key = cls._create_key(**key_kwargs)
            assert cls._is_valid_key(key)
            key_parts = dict(key_kwargs)
        return f"{cls.PREFIX}/{key}", key, key_parts

    def __init__(self, txn: BaseTransaction):
        self._txn = txn

    def __call__(self, *args, **kwargs) -> EntityOperations[ModelT]:
        """
        A convenience method that internally forwards all arguments to
        :meth:`index_by_key_parts`. Available only to entities modelled as
        plain dictionaries. Entities modelled via pydantic should use the
        higher level methods offered by this class instead.
        """
        if self.MODEL_CLASS != dict:
            raise NotImplementedError(
                "Only available for entities modelled as plain dictionaries"
            )
        return self.index_by_key_parts(*args, **kwargs)

    def index_by_key_parts(
        self, *key_args, **key_kwargs
    ) -> EntityOperations[ModelT]:
        """
        Return operations over a single entity.

        Individual entities are accessed by specifying their *key*.
        This entity key can be given in the following ways:

        * Via a ``key`` keyword argument, in which case it's taken verbatim
          and its constituent parts are extracted.
        * Via multiple keyword arguments that make up the full key.
          The keyword argument names must correspond to the key part names.
        * If the key has a single part, it can be given as a positional
          argument for ease of use.
        """
        path, key, key_parts = self._get_key_and_path(*key_args, **key_kwargs)
        return self._ENTITY_OPERATIONS_CLASS(self._txn, path, key, key_parts)

    def query_keys(self, **key_matches_kwargs) -> Iterator[ModelKeyT]:
        """
        Iterate over the keys matching the given constraints.

        Constraints are given via keyword arguments, where each keyword
        argument name corresponds to a key part of this entity type, and the
        argument value corresponds to the value the key part should have in
        the database to match the query. If a keyword argument name doesn't
        match any key part name a :class:`ValueError` is raised.

        Keyword argument names can be suffixed with ``_prefix`` or ``_suffix``,
        in which case matching is done not exactly, but by prefix or suffix
        respectively.

        Invalid keys in the database are ignored.
        """
        key_model = self._key_model()
        for key in self._txn.list_keys(self.PREFIX + "/"):
            result = self._KEY_PATTERN.fullmatch(key)
            if not result:
                logger.warning(
                    "Key %s doesn't have expected format, skipping", key
                )
                continue
            if not self._key_matches(result, **key_matches_kwargs):
                continue
            if key_model is None:
                yield key
            else:
                # Build the key model from the regex groups: splitting on ":"
                # breaks when a part value itself contains ":".
                yield key_model(**result.groupdict())

    def list_keys(self, **key_matches_kwargs) -> list[ModelKeyT]:
        """Like :meth:`query_keys`, but returns a list."""
        return list(self.query_keys(**key_matches_kwargs))

    def query_values(
        self, **key_matches_kwargs
    ) -> Iterator[tuple[ModelKeyT, ModelT]]:
        """Like :meth:`query_keys`, but yields ``(key, value)`` pairs."""
        uses_key_model = self._key_model() is not None
        for key in self.query_keys(**key_matches_kwargs):
            key_parts = key.model_dump() if uses_key_model else {"key": key}
            yield key, self.index_by_key_parts(**key_parts).get()

    def list_values(
        self, **key_matches_kwargs
    ) -> list[tuple[ModelKeyT, ModelT]]:
        """Like :meth:`query_values`, but returns a list."""
        return list(self.query_values(**key_matches_kwargs))

    def _index_by_key_or_value(
        self, index: ModelT | ModelKeyT
    ) -> EntityOperations[ModelT]:
        """Entity operations for a model value or a model key."""
        if isinstance(index, self.MODEL_CLASS):
            return self._index_by_key(index.key)
        if isinstance(index, self.MODEL_CLASS.Key):
            return self._index_by_key(index)
        cls_name = self.MODEL_CLASS.__name__
        key_cls_name = self.MODEL_CLASS.Key.__name__
        raise ValueError(
            f"{index} is not an instance of {cls_name} or {key_cls_name}"
        )

    def _index_by_value(self, index: ModelT) -> EntityOperations[ModelT]:
        """Entity operations for a model value (via its key)."""
        if isinstance(index, self.MODEL_CLASS):
            return self._index_by_key(index.key)
        cls_name = self.MODEL_CLASS.__name__
        raise ValueError(f"{index} is not an instance of {cls_name}")

    def _index_by_key(self, index: ModelKeyT) -> EntityOperations[ModelT]:
        """Entity operations for a model key (string or key model)."""
        assert isinstance(index, self.MODEL_CLASS.Key)
        if self.MODEL_CLASS.Key != str:
            key_kwargs = index.model_dump()
        else:
            key_kwargs = {"key": index}
        return self.index_by_key_parts(**key_kwargs)

    @pydantic_only
    def path(self, index: ModelT | ModelKeyT) -> str:
        """
        The full path in the SDP Configuration Database for the given entity.
        """
        return self._index_by_key_or_value(index).path

    @pydantic_only
    def get(self, index: ModelT | ModelKeyT) -> ModelT | None:
        """
        Get an entry from the database, or ``None`` if it isn't present.
        ``index`` is either an entity value (whose key is used) or an entity
        key. Available only for entities defined as pydantic models.
        """
        return self._index_by_key_or_value(index).get()

    @pydantic_only
    def exists(self, index: ModelT | ModelKeyT) -> bool:
        """
        Check if the entity exists. ``index`` is either an entity value (whose
        key is used) or an entity key. Available only for entities defined as
        pydantic models.
        """
        return self._index_by_key_or_value(index).exists()

    @pydantic_only
    @stateful_only
    def state(self, index: ModelT | ModelKeyT) -> StateOperations:
        """
        Get the operations on an entity's state. Available only for entities
        defined as pydantic models that declare a state.
        """
        return self._index_by_key_or_value(index).state

    @pydantic_only
    @owned_only
    def ownership(self, index: ModelT | ModelKeyT) -> OwnershipOperations:
        """
        Get the operations on entity's ownership. Available only for entities
        defined as pydantic models declaring ownership.
        """
        return self._index_by_key_or_value(index).ownership

    @pydantic_only
    @owned_only
    def is_alive(self, index: ModelT | ModelKeyT) -> bool:
        """
        Whether this entity is alive. Available only for entities
        defined as pydantic models declaring ownership.
        """
        return self._index_by_key_or_value(index).is_alive()

    @pydantic_only
    @owned_only
    def take_ownership_if_not_alive(self, index: ModelT | ModelKeyT) -> None:
        """
        Takes ownership of this entity if it's not currently alive. Available
        only for entities defined as pydantic models declaring ownership.
        """
        return self._index_by_key_or_value(index).take_ownership_if_not_alive()

    @pydantic_only
    def create(self, value: ModelT) -> None:
        """
        Create an entry in the database for this entity. Available only for
        entities defined as pydantic models.
        """
        self._index_by_value(value).create(value)

    @pydantic_only
    def update(self, value: ModelT) -> None:
        """
        Updates an entry in the database with the given entity's contents.
        Available only for entities defined as pydantic models.
        """
        self._index_by_value(value).update(value)

    @pydantic_only
    def create_or_update(self, value: ModelT) -> None:
        """
        Creates or updates an entry in the database with the given entity's
        contents. Available only for entities defined as pydantic models.
        """
        self._index_by_value(value).create_or_update(value)

    @pydantic_only
    def delete(self, index: ModelT | ModelKeyT, recurse: bool = False) -> None:
        """
        Deletes this entry from the database. Note that only the key fields
        from the entity are used to determine the database entry, the rest of
        the contents don't need to match those in the database. Available only
        for entities defined as pydantic models.
        """
        self._index_by_key_or_value(index).delete(recurse=recurse)