"""Etcd3 backend for SKA SDP configuration DB."""
# pylint: disable=duplicate-code
from __future__ import annotations
import logging
import os
import time
from typing import TYPE_CHECKING, Any, Callable, Iterable, Optional, cast
import etcd3
import requests
import semantic_version
from .backend import Backend, DbRevision, DbTransaction, Lease, RecurseType
from .common import (
ConfigCollision,
ConfigVanished,
_check_path,
_tag_depth,
_untag_depth,
depth_of_path,
)
if TYPE_CHECKING:
    from .etcd3_watcher import Etcd3Watcher

LOGGER = logging.getLogger(__name__)

# Change the log level for the imported package 'etcd3'
# and the dependent package 'urllib3'.
# logging only recognises upper-case level names, so normalise the
# environment value (e.g. "debug" -> "DEBUG") instead of failing at import.
lowerLevelLog = os.getenv("SDP_CONFIG_ETCD3_LOG_LEVEL", "INFO").upper()
logging.getLogger("etcd3").setLevel(lowerLevelLog)
logging.getLogger("urllib3").setLevel(lowerLevelLog)
class Etcd3Backend(Backend):
    """
    Highly consistent database backend store.

    See https://github.com/kragniz/python-etcd3
    """

    def __init__(
        self,
        host="127.0.0.1",
        port="2379",
        max_retries: int = 15,
        retry_time: float = 0.1,
        **kw_args,
    ):
        """
        Instantiate the database client.

        :param host: etcd server host
        :param port: etcd server port
        :param max_retries: Number of protected attempts made by
            :py:meth:`_retry_loop` before one final unprotected attempt
        :param retry_time: Initial delay between retries in seconds
            (multiplied by 1.5 after every failed attempt)
        :param kw_args: ``uses_secure_channel``, ``creds`` and
            ``grpc_options`` configure the endpoint; all keyword
            arguments are additionally forwarded to the etcd3 client
        :raises RuntimeError: if the etcd server version is too old
        """
        self._max_retries = max_retries
        self._retry_time = retry_time
        # Make endpoint for (presumably singular) host with retry timeout of 10
        # seconds - that is the fastest that a gRPC connection can exit the
        # UNAVAILABLE state apparently.
        endpoint = etcd3.Endpoint(
            host,
            port,
            secure=kw_args.get("uses_secure_channel"),
            creds=kw_args.get("creds"),
            opts=kw_args.get("grpc_options"),
            time_retry=10,
        )
        # Create "multi-endpoint" client with failover so that we get
        # NoServerAvailableError raised.
        # NOTE(review): kw_args is forwarded wholesale, including the
        # endpoint-only keys consumed above - confirm the client
        # constructor accepts them.
        self._client = etcd3.MultiEndpointEtcd3Client(
            endpoints=[endpoint], failover=True, **kw_args
        )
        self._verify_server_version()

    def _verify_server_version(self):
        """
        Verify that etcd server release is new enough to guarantee
        correct packet order for progress notifications.

        :raises requests.HTTPError: if the version endpoint returns an
            error status
        :raises RuntimeError: if the server version is not supported
        """
        # Get version via HTTP
        endpoint = self._client.endpoint_in_use
        response = requests.get(
            f"{endpoint.protocol}://{endpoint.netloc}/version", timeout=0.3
        )  # 300ms will do?
        response.raise_for_status()
        # Progress notifications are handled correctly from 3.4.26 and
        # 3.5.8 forward. We also allow 3.6.0 prerelease so development
        # builds work.
        ver = semantic_version.Version(
            response.json().get("etcdserver", "0.0.0")
        )
        LOGGER.debug("Detected etcd server version %s", ver)
        spec = semantic_version.NpmSpec(
            ">=3.4.26 <3.5 || >=3.5.8 || >=3.6.0-a"
        )
        if ver not in spec:
            raise RuntimeError(
                f"Etcd Server version is {ver}, need 3.4.26 or 3.5.8 for "
                "watcher to work correctly!"
            )

    def _retry_loop(self, code_to_try: Callable) -> Any:
        """
        Helper that retries code if an exception gets thrown that
        typically indicates a loss of connection. Note that this *can*
        rarely mean that the effect of the code in question was
        executed multiple times.

        :param code_to_try: Zero-argument callable to run
        :returns: Whatever ``code_to_try`` returns
        :raises: Any exception from the final, unprotected attempt
        """
        retryable = (
            etcd3.exceptions.ConnectionFailedError,
            etcd3.exceptions.ConnectionTimeoutError,
            etcd3.exceptions.NoServerAvailableError,
        )
        retry_time = self._retry_time
        for attempt in range(self._max_retries):
            try:
                return code_to_try()
            except retryable as ex:
                LOGGER.warning(
                    "Caught %s, retry %d after %gs",
                    repr(ex),
                    attempt,
                    retry_time,
                )
            # Delay before next iteration
            time.sleep(retry_time)
            retry_time *= 1.5  # back off
        # Attempt one final time - without safety net
        return code_to_try()

    def get(
        self, path: str, revision: Optional[DbRevision] = None
    ) -> tuple[Optional[str], DbRevision]:
        """
        Get the value of a key.

        :param path: Path of key to query
        :param revision: Database revision to read at (latest if None)
        :returns: (value or None if the key does not exist,
            revision of the response header)
        """
        # Check/prepare parameters
        _check_path(path)
        tagged_path = _tag_depth(path)
        rev = None if revision is None else revision.revision
        # Get value and revision
        range_response = self._retry_loop(
            lambda: self._client.get_response(tagged_path, revision=rev)
        )
        # handle non-existence of key
        value = None
        if range_response.count >= 1:
            popped_kv_pair = range_response.kvs.pop()
            value = popped_kv_pair.value.decode("utf-8")
        # set revision whether key exists or not
        return value, DbRevision(range_response.header.revision)

    def create(
        self, path: str, value: str, lease: Optional[etcd3.Lease] = None
    ) -> None:
        """
        Create a key that must not exist yet.

        :param path: Path of key to create
        :param value: Value to store (converted with ``str``)
        :param lease: Optional lease to attach the key to
        :raises ConfigCollision: if the key already exists
        """
        # Prepare parameters
        _check_path(path)
        tagged_path = _tag_depth(path)
        lease_id = 0 if lease is None else lease.id
        value = str(value)
        response = self._retry_loop(
            lambda: self._client.put_if_not_exists(
                tagged_path, value, lease_id
            )
        )
        if not response:
            raise ConfigCollision(
                path, f"Cannot create {path}, as it already exists!"
            )

    def update(
        self,
        path: str,
        value: str,
    ) -> None:
        """
        Update a key that must already exist.

        :param path: Path of key to update
        :param value: New value (converted with ``str``)
        :raises ConfigVanished: if the key does not exist
        """
        # Validate parameters
        _check_path(path)
        tagged_path = _tag_depth(path)
        value = str(value)
        # Execute in a transaction.
        # Supported operators are equality/less/greater (not boolean).
        status, _ = self._retry_loop(
            lambda: self._client.transaction(
                compare=[self._client.transactions.version(tagged_path) > 0],
                success=[self._client.transactions.put(tagged_path, value)],
                failure=[],
            )
        )
        if not status:
            raise ConfigVanished(
                path, f"Cannot update {path}, as it does not exist!"
            )

    def list_keys(
        self,
        path: str,
        recurse: RecurseType = 0,
        revision: Optional[DbRevision] = None,
        with_values: bool = False,
    ) -> tuple[list, DbRevision]:
        """
        List keys under given path.

        :param path: Prefix of keys to query. Append '/' to list
            child paths.
        :param recurse: Maximum recursion level to query. If iterable,
            cover exactly the recursion levels specified.
        :param revision: Database revision for which to list
        :param with_values: Also return raw (undecoded) key values and a
            revision, i.e. a list of (key, value, DbRevision) tuples
            sorted by key. NOTE(review): the revision recorded is the one
            the listing was read at, not each key's mod revision - confirm
            which one callers expect.
        :returns: (sorted key list, DbRevision object). If no recursion
            level is requested, the list is empty.
        """
        # Prepare parameters
        path_depth = depth_of_path(path)
        rev = None if revision is None else revision.revision
        keys = []
        vals = []
        revs = []
        if isinstance(recurse, Iterable):
            depth_iter = iter(recurse)
        else:
            depth_iter = range(recurse + 1)
        # Stays None if depth_iter is empty (e.g. recurse=() or -1)
        range_response = None
        for depth in depth_iter:
            tagged_path = _tag_depth(path, depth + path_depth)
            # Bind loop variables as defaults so the closure is self-contained
            range_response = self._retry_loop(
                lambda prefix=tagged_path, at_rev=rev: (
                    self._client.get_prefix_response(
                        prefix, revision=at_rev, keys_only=not with_values
                    )
                )
            )
            # Pin all further levels to the revision of the first query
            if rev is None:
                rev = range_response.header.revision
            for kv_pair in range_response.kvs:
                keys.append(_untag_depth(kv_pair.key))
                if with_values:
                    vals.append(kv_pair.value)
                    revs.append(DbRevision(rev))
        revision = DbRevision(rev)
        if range_response is None:
            return [], revision
        if with_values:
            return (
                sorted(zip(keys, vals, revs), key=lambda kv: kv[0]),
                revision,
            )
        return sorted(keys), revision

    def lease(self, ttl: float = 10) -> Lease:
        """
        Generate a new lease.

        Once entered, it can be associated with keys which will be kept
        alive until the end of the lease.
        Note that this involves starting a daemon thread that will refresh
        the lease periodically (default seems to be TTL/4).

        :param ttl: Time to live for lease
        :return: lease object
        """
        return self._retry_loop(
            lambda: cast(Lease, self._client.lease(ttl=ttl))
        )

    def txn(self, max_retries: int = 64) -> Iterable["Etcd3Transaction"]:
        """
        Create a transaction loop.

        Yields the same transaction object until it commits successfully.

        :param max_retries: Maximum number of failed commits tolerated
        :returns: Iterator over transaction attempts
        """
        yield from Etcd3Transaction(self, self._client, max_retries)

    def watcher(
        self,
        timeout=None,
        txn_wrapper: Optional[Callable[["Etcd3Transaction"], object]] = None,
        requery_progress: float = 0.2,
    ) -> Iterable[Etcd3Watcher]:
        """Create a new watcher.

        Useful for waiting for changes in the configuration. See
        :py:class:`.etcd3_watcher.Etcd3Watcher`.

        :param timeout: Timeout for waiting. Watcher will loop after this time.
        :param txn_wrapper: Function to wrap transactions returned by the
            wrapper.
        :param requery_progress: How often we "refresh" the current
            database state for watcher transactions even without
            watcher notification (upper bound on how "stale"
            non-watched values retrieved in transactions can be)
        :returns: Watcher iterator
        """
        # To get around cyclic imports
        # pylint: disable=import-outside-toplevel
        from .etcd3_watcher import Etcd3Watcher

        return Etcd3Watcher(
            self, self._client, timeout, txn_wrapper, requery_progress
        )

    def _delete_recursive(
        self,
        path: str,
        must_exist: bool = True,
        prefix: bool = False,
        max_depth: int = 16,
    ) -> int:
        """
        Delete the keys below ``path``, one depth level at a time.

        Factored out from :py:meth:`delete` due to too high cognitive
        complexity.

        :param path: Path whose descendants are deleted
        :param must_exist: Raise if no key at all was deleted
        :param prefix: Use ``path`` as a plain prefix instead of
            appending "/" to it
        :param max_depth: Bound on the number of depth levels visited
        :returns: Number of keys deleted
        :raises ConfigVanished: if ``must_exist`` and nothing was deleted
        """
        depth = depth_of_path(path)
        base = path if prefix else path + "/"
        delete_count = 0
        for level in range(depth + 1, depth + max_depth):
            dpath = _tag_depth(base, level)
            prefix_response = self._retry_loop(
                lambda dpath=dpath: self._client.delete_prefix(dpath)
            )
            # Count removed keys via the ``deleted`` field (as delete()
            # does), not via the truthiness of the response object.
            delete_count += prefix_response.deleted
        if delete_count < 1 and must_exist:
            raise ConfigVanished(
                path, f"Cannot delete {path}, as it does not exist!"
            )
        return delete_count

    # pylint: disable=too-many-arguments
    def delete(
        self,
        path: str,
        must_exist: bool = True,
        recursive: bool = False,
        prefix: bool = False,
        max_depth: int = 16,
    ):
        """
        Delete a key (or every key with a given prefix).

        :param path: Path (or prefix) of key(s) to delete
        :param must_exist: Raise if nothing was deleted
        :param recursive: Also delete keys below ``path``
        :param prefix: Treat ``path`` as a prefix
        :param max_depth: Bound on depth levels visited when recursive
        :raises ConfigVanished: if ``must_exist`` and neither the key nor,
            when recursive, any key below it existed
        """
        # Prepare parameters
        tagged_path = _tag_depth(path)
        if prefix:
            prefix_response = self._retry_loop(
                lambda: self._client.delete_prefix(tagged_path)
            )
            found = prefix_response.deleted >= 1
        else:
            found = self._retry_loop(lambda: self._client.delete(tagged_path))
        if recursive:
            # Children count towards existence, matching the semantics of
            # Etcd3Transaction.delete; existence is checked once below.
            children = self._delete_recursive(path, False, prefix, max_depth)
            found = found or children > 0
        if not found and must_exist:
            raise ConfigVanished(
                path, f"Cannot delete {path}, as it does not exist!"
            )

    def close(self) -> None:
        """Close the underlying etcd3 client."""
        self._client.close()
# pylint: disable=too-many-instance-attributes
class Etcd3Transaction(DbTransaction):
    """
    A series of queries and updates to be executed atomically.

    Reads go through the backend and are recorded in query logs, pinned
    to one database revision (baked in on the first read). Writes are
    buffered locally and only sent on :py:meth:`commit`, as a single etcd
    transaction guarded by comparisons checking that none of the logged
    reads have been invalidated since.
    """

    def __init__(
        self,
        backend: Etcd3Backend,
        client: etcd3.MultiEndpointEtcd3Client,
        max_retries: int = 64,
    ):
        """
        Initialise transaction.

        :param backend: Backend used for reads and for its retry loop
        :param client: etcd3 client used to build and submit the commit
        :param max_retries: Number of failed commits tolerated by
            :py:meth:`__iter__` before giving up
        """
        super().__init__(backend)
        self._client = client
        self._max_retries = max_retries
        # Revision baked in after first read
        self._revision: Optional[DbRevision] = None
        # Get query log: path -> (value, revision it was read at)
        self._get_queries: dict[str, tuple[Optional[str], DbRevision]] = {}
        # Delayed updates: path -> (value, or None for delete; lease)
        self._updates: dict[str, tuple[Optional[str], Optional[Lease]]] = {}
        # List query log: (path, absolute depth) -> (keys, revision)
        self._list_queries: dict[
            tuple[str, int], tuple[list[str], DbRevision]
        ] = {}
        self._committed = False
        self._retries = 0
        self._commit_callbacks: list[Callable[[], None]] = []

    @property
    def revision(self) -> int:
        """The last-committed database revision.

        Only valid to call after the transaction has been committed.

        :returns: revision from DbRevision
        :raises RuntimeError: if the transaction is not committed yet
        """
        if not self._committed:
            raise RuntimeError(
                "Revision is undefined on an uncommitted transaction!"
            )
        return self._revision.revision

    def _ensure_uncommitted(self) -> None:
        """Raise RuntimeError if the transaction was already committed."""
        if self._committed:
            raise RuntimeError("Attempted to modify committed transaction!")

    # pylint: disable=duplicate-code
    def get(self, path: str) -> Optional[str]:
        """
        Get value of a key.

        Pending updates of this transaction take precedence; otherwise the
        value is read once (at the transaction revision) and logged.

        :param path: Path of key to query
        :returns: Key value. None if it doesn't exist.
        """
        self._ensure_uncommitted()
        # Check whether it was written as part of this transaction
        if path in self._updates:
            return self._updates[path][0]
        # Check whether we already have the request response
        if path in self._get_queries:
            return self._get_queries[path][0]
        # Perform get request and record it in the query log
        val, rev = self._get_queries[path] = self.backend.get(
            path, revision=self._revision
        )
        if self._revision is None:
            self._revision = rev
        return val

    def create(
        self, path: str, value: str, lease: Optional[etcd3.Lease] = None
    ) -> None:
        """
        Schedule creation of a key that must not exist yet.

        :param path: Path of key to create
        :param value: Value to store (converted with ``str``)
        :param lease: Optional lease to attach the key to
        :raises ConfigCollision: if the key already exists
        """
        self._ensure_uncommitted()
        value = str(value)
        # Attempt to get the value - mainly to check whether it exists
        # and put it into the query log
        if self.get(path) is not None:
            raise ConfigCollision(
                path, f"Cannot create {path}, as it already exists!"
            )
        # Add update request
        self._updates[path] = (value, lease)

    # pylint: disable=duplicate-code
    def update(self, path: str, value: str) -> None:
        """
        Schedule an update of a key that must already exist.

        :param path: Path of key to update
        :param value: New value (converted with ``str``)
        :raises ConfigVanished: if the key does not exist
        """
        self._ensure_uncommitted()
        value = str(value)
        if self.get(path) is None:
            raise ConfigVanished(
                path, f"Cannot update {path}, as it does not exist!"
            )
        # Keep a lease attached by an earlier create() in this same
        # transaction - the final put would otherwise drop it.
        _, lease = self._updates.get(path, (None, None))
        self._updates[path] = (value, lease)

    # pylint: disable=too-many-arguments
    def delete(
        self,
        path: str,
        must_exist: bool = True,
        recursive: bool = False,
        max_depth: int = 16,
        prefix: bool = False,
    ) -> None:
        """
        Schedule deletion of a key (and optionally keys below it).

        :param path: Path (or prefix) of key(s) to delete
        :param must_exist: Raise if no matching key exists
        :param recursive: Also delete keys below ``path``
        :param max_depth: Recursion depth used when listing children
        :param prefix: Treat ``path`` as a prefix
        :raises ConfigVanished: if ``must_exist`` and nothing matches
        """
        keys = []
        if prefix:
            keys = self.list_keys(path, recurse=max_depth if recursive else 0)
        else:
            if self.get(path) is not None:
                keys = [path]
            if recursive:
                keys += self.list_keys(path + "/", recurse=max_depth)
        if must_exist and not keys:
            raise ConfigVanished(
                path, f"Cannot delete {path}, it does not exist!"
            )
        # Add delete request
        for key in keys:
            self._updates[key] = (None, None)

    def _compare_list(self, txn: etcd3.Transactions) -> list:
        """
        Build the etcd compare conditions validating the query logs.

        :param txn: etcd3 transaction helper used to build comparisons
        :returns: List of comparisons that must all hold for the commit
        """
        compare_list = []
        # For every get call add revision comparison to compare list
        for path, (_, rev) in self._get_queries.items():
            tagged_path = _tag_depth(path)
            if rev.revision is None:
                # Defensive: key did not exist? Verify it still doesn't
                # exist. Note that the key could have been created and
                # deleted in the meantime.
                compare_list.append(txn.version(tagged_path) == 0)
            else:
                # The key must not have been modified after the revision
                # we read it at (mod revision <= read revision).
                compare_list.append(txn.mod(tagged_path) < rev.revision + 1)
        # Verify list_keys calls from the query log
        for (path, depth), (result, _) in self._list_queries.items():
            tagged_path = _tag_depth(path, depth)
            # check returned list of keys still exist
            for res_path in result:
                tagged_res_path = _tag_depth(res_path)
                compare_list.append(txn.version(tagged_res_path) > 0)
            # check no new keys have been added to the returned list
            # by checking whether the request contains any keys with
            # create revisions newer than the embedded revision of the
            # request
            tagged_path_end = etcd3.utils.prefix_range_end(tagged_path)
            compare_list.append(
                txn.create(tagged_path, tagged_path_end)
                < self._revision.revision + 1
            )
        return compare_list

    def _success_list(self, txn: etcd3.Transactions) -> list:
        """
        Build the etcd put/delete operations for the buffered updates.

        :param txn: etcd3 transaction helper used to build operations
        :returns: List of operations to run if all comparisons hold
        """
        success_list = []
        for path, (value, lease) in self._updates.items():
            tagged_path = _tag_depth(path)
            if value is None:
                # A None value marks a scheduled deletion
                success_list.append(txn.delete(tagged_path))
            else:
                lease_id = None if lease is None else lease.id
                success_list.append(
                    txn.put(tagged_path, value, lease=lease_id)
                )
        return success_list

    # pylint: disable=protected-access
    def commit(self) -> bool:
        """
        Attempt to commit the buffered updates atomically.

        A transaction without updates is marked committed without
        checking the query logs. On success, registered commit callbacks
        are called and cleared.

        :returns: True if committed, False if a logged read was
            invalidated in the meantime
        """
        self._ensure_uncommitted()
        # If we have made no updates, we don't need to verify the get query log
        if not self._updates:
            self._committed = True
            return True
        # Use the transaction from the etcd3 client
        txn: etcd3.Transactions = self._client.transactions
        # The client transaction method carries out the actions
        # in the success_list if all assertions in the compare_list
        # are true.
        succeeded, _ = self.backend._retry_loop(
            lambda: self._client.transaction(
                compare=self._compare_list(txn),
                success=self._success_list(txn),
                failure=[],
            )
        )
        # Done
        self._committed = True
        if succeeded:
            for callback in self._commit_callbacks:
                callback()
            self._commit_callbacks = []
        return succeeded

    def on_commit(self, callback: Callable[[], None]) -> None:
        """Register a callback to call when the transaction succeeds.

        Exists mostly to enable test cases.

        :param callback: Callback to call
        """
        self._commit_callbacks.append(callback)

    def reset(self, revision: Optional[DbRevision] = None) -> None:
        """
        Prepare a committed transaction for another attempt.

        Clears the query logs and buffered updates.
        NOTE(review): registered commit callbacks are not cleared, so
        callbacks registered again on each retry accumulate - confirm
        that this is intended.

        :param revision: Revision to read at (None: bake in on next read)
        :raises RuntimeError: if the transaction is not committed
        """
        if not self._committed:
            raise RuntimeError("Called reset on an uncommitted transaction!")
        self._revision = revision
        self._get_queries.clear()
        self._list_queries.clear()
        self._updates.clear()
        self._committed = False

    def list_keys(self, path: str, recurse: RecurseType = 0) -> list[str]:
        """
        List keys under given path, including pending updates.

        Each depth level is queried once, at the transaction revision,
        and logged so that :py:meth:`commit` can verify it.

        :param path: Prefix of keys to query
        :param recurse: Maximum recursion level, or iterable of levels
        :returns: Sorted list of keys
        """
        self._ensure_uncommitted()
        path_depth = depth_of_path(path)
        # Walk through depths, collecting known keys
        if isinstance(recurse, Iterable):
            depth_iter = iter(recurse)
        else:
            depth_iter = range(recurse + 1)
        keys: list[str] = []
        for depth in depth_iter:
            tagged_path = _tag_depth(path, path_depth + depth)
            # Pending updates of this transaction at this level
            matching_vals = [
                (key, val)
                for key, val in self._updates.items()
                if _tag_depth(key).startswith(tagged_path)
            ]
            added_keys = {
                key for key, val in matching_vals if val[0] is not None
            }
            removed_keys = {key for key, val in matching_vals if val[0] is None}
            query = (path, path_depth + depth)
            if query not in self._list_queries:
                # Read at the transaction revision, like get(), so that all
                # reads of this transaction see one consistent snapshot
                self._list_queries[query] = self.backend.list_keys(
                    path, recurse=(depth,), revision=self._revision
                )
            # Add to key set
            result, rev = self._list_queries[query]
            keys.extend((set(result) - removed_keys) | added_keys)
            # Bake in revision if not already done so
            if self._revision is None:
                self._revision = rev
        return sorted(keys)

    def __iter__(self) -> Iterable["Etcd3Transaction"]:
        """
        Iterate transaction until it succeeds.

        :raises RuntimeError: if the commit still fails after
            ``max_retries`` retries
        """
        try:
            while self._retries <= self._max_retries:
                # Should build up a transaction
                yield self
                # Try to commit, count how many times we have tried
                if not self.commit():
                    self._retries += 1
                else:
                    self._retries = 0
                    return
                self.reset()
        finally:
            if self._updates and not self._committed:
                LOGGER.warning(
                    "Transaction loop aborted - dropping updates to %s!",
                    list(self._updates.keys()),
                )
        raise RuntimeError(
            f"Transaction did not succeed after {self._max_retries} retries!"
        )