# pylint: disable=too-many-lines
"""
Etcd3 backend for SKA SDP configuration DB, implementation of a
caching watcher.
"""
import logging
import queue
import threading
import time
from typing import Any, Callable, Optional
import etcd3
import grpc
from etcd3.etcdrpc import rpc_pb2, rpc_pb2_grpc
from etcd3.etcdrpc.kv_pb2 import Event
from .backend import DbRevision, Watcher
from .common import _tag_depth, _untag_depth, depth_of_path
from .etcd3 import Etcd3Backend, Etcd3Transaction
LOGGER = logging.getLogger(__name__)
# pylint: disable=too-many-instance-attributes
[docs]
class Etcd3Watcher(Watcher):
"""Watch for database changes by using nested transactions
Use as follows:
.. code-block:: python
for watcher in config.watcher():
for txn in watcher.txn():
# ... do something
for txn in watcher.txn():
# ... do something else
At the end of a for loop iteration, the watcher will start
watching all values read by transactions started through
:py:meth:`txn`, and only repeat the execution of the loop body
once one of these values has changed.
"""
# pylint: disable=too-many-arguments
def __init__(
    self,
    backend: "Etcd3Backend",
    client: etcd3.Etcd3Client,
    timeout: float = None,
    txn_wrapper: Callable[[Etcd3Transaction], object] = None,
    requery_progress: float = 0.2,
    max_retries: int = 15,
    retry_time: float = 0.1,
):
    """Initialise watcher state; no server connection is opened here.

    :param backend: Backend used for the actual get / list queries
    :param client: etcd3 client whose gRPC channel carries the watches
    :param timeout: Forwarded to the ``Watcher`` base class
    :param txn_wrapper: Forwarded to the ``Watcher`` base class; if set,
        it is applied to every transaction handed out by :py:meth:`txn`
    :param requery_progress: Time in seconds for which a revision
        confirmed by the server may be reused before asking again
    :param max_retries: Number of times a failed database interaction
        is retried
    :param retry_time: Initial back-off in seconds after a failed
        database interaction; grows by 50% per failed attempt
    """
    super().__init__(backend, timeout, txn_wrapper)

    # --- Connection -----------------------------------------------------
    self._client = client
    # Watch connection to the etcd server; only set while iterating
    self._conn: _Etcd3WatcherConnection = None

    # --- Retry / timeout tuning -----------------------------------------
    self._retry_time = retry_time
    self._max_retries = max_retries
    # How long to wait for the answer to a progress request (which
    # implicitly also bounds watcher creation and cancellation)
    self._request_timeout = 0.5
    # Connection re-establishments tolerated within a single refresh
    self._max_reconnects = 5

    # --- Per-iteration bookkeeping --------------------------------------
    # Everything that was queried during the current loop iteration
    self._get_queries = set()  # path
    self._list_queries = set()  # (depth, path)
    # Set once a value queried in this iteration has been updated or
    # overwritten
    self._dirty = False

    # --- Revision tracking ----------------------------------------------
    # Database revision matching the cached values below. It is only
    # advanced once all watch notifications for a revision are known to
    # have been processed, which is what refresh() arranges.
    self._revision = None
    self._last_progress = None
    self._requery_progress = requery_progress

    # --- Cached database state at self._revision ------------------------
    # A path is known not to exist if either
    #  1. _cache[path] is None, or
    #  2. path is missing from _list_cache[depth(path)][prefix] for a
    #     watched prefix of that path.
    self._cache = {}  # path -> value
    self._list_cache = {}  # depth -> prefix -> paths
    # True while the caches might lag behind _revision, in which case
    # refresh() must run before they are used
    self._cache_needs_refresh = False

    self._got_timeout = False  # For test cases
# pylint: disable=cell-var-from-loop
def _start_watch_connection(self):
    """
    (Re-)open the watch connection to the server.

    An existing connection is closed first. Every get and list watcher
    registered on it is then created again on the new connection,
    starting from the current revision. All steps go through the retry
    loop.
    """
    # Remember what the old connection was watching, then shut it down
    old_conn = self._conn
    if old_conn is None:
        watched_paths, watched_lists = [], []
    else:
        watched_paths = list(old_conn.get_watchers)
        watched_lists = list(old_conn.list_watchers)
        old_conn.close()

    # Open the replacement connection
    def connect():
        self._conn = _Etcd3WatcherConnection(self._client.channel)

    self._retry_loop(connect)

    # Register the previous watchers again. Loop variables are bound as
    # lambda defaults; the revision is looked up when the call is made.
    for path in watched_paths:
        self._retry_loop(
            lambda path=path: self._conn.create_get_watcher(
                path, self._revision
            )
        )
    for depth, prefix in watched_lists:
        self._retry_loop(
            lambda depth=depth, prefix=prefix: (
                self._conn.create_list_watcher(
                    prefix, depth, self._revision
                )
            )
        )
def _retry_loop(self, code_to_try: Callable) -> Any:
    """Run a callable, retrying with growing back-off on gRPC errors.

    :param code_to_try: Callable without arguments
    :returns: Whatever ``code_to_try`` returns
    :raises grpc.RpcError: If the final, unguarded attempt fails too
    """
    backoff = self._retry_time
    for attempt in range(self._max_retries):
        try:
            return code_to_try()
        except grpc.RpcError as ex:
            LOGGER.warning(
                "Caught %s, retry %d after %gs",
                repr(ex),
                attempt,
                backoff,
            )
            time.sleep(backoff)
            # Wait 50% longer before the next attempt
            backoff *= 1.5

    # All guarded attempts used up - a last try that may raise
    return code_to_try()
def __iter__(self):
    """
    Iterate forever, yielding this watcher once per loop iteration and
    waiting after every iteration for something to change.

    The watch connection is opened on entry and always closed again
    when the generator exits.

    :raises RuntimeError: If a watcher loop is already running on
        this object (a connection already exists).
    """
    if self._conn is not None:
        raise RuntimeError("Watcher is not re-entrant!")
    LOGGER.debug("Entering watcher loop")
    try:
        # Create connection
        self._start_watch_connection()
        while True:
            # Initialise loop iteration state
            self._get_queries = set()
            self._list_queries = set()
            self._dirty = False
            # NOTE(review): _wake_up_at is not initialised in this
            # class's __init__ - presumably owned by the Watcher base
            # class; confirm.
            self._wake_up_at = None
            # Yield to loop body
            yield self
            # Dirty - i.e. the iteration has internally changed a value
            # that it read earlier? That means its state is internally
            # inconsistent, and we should loop immediately. Let's
            # especially keep watchers alive until we have a clean loop
            # iteration.
            if self._dirty:
                continue
            # Clean up any unused watchers. Note that close_unused()
            # modifies the query set and the caches passed in.
            self._conn.close_unused(
                self._get_queries,
                self._list_queries,
                self._cache,
                self._list_cache,
                self._revision,
            )
            self._get_queries = set()
            self._list_queries = set()
            # Determine timeout and wait until something happens
            self._refresh(self.get_timeout())
    finally:
        # Runs on normal exit, on exceptions and when the generator is
        # closed: make sure the connection does not outlive the loop.
        LOGGER.debug("Exiting watcher loop")
        if self._conn is not None:
            self._conn.close()
            self._conn = None
# pylint: disable=too-many-branches,too-many-statements
def _refresh(self, delay_if_no_change: float):
    """Checks incoming watch notifications and updates caches accordingly.

    Returns revision that represents the current stream state.
    After return, we guarantee that:

    1) the returned revision is consistent with the cache
    2) the cache is at most self._requery_progress seconds old
    3) all in-flight watcher creations / cancellations have been done

    Note that properly verifying that the cache is not stale might require
    asking for a progress notification from the server, so if the last
    refresh wasn't too long ago we will skip the refresh and simply
    continue using the previous revision.

    :param delay_if_no_change: Maximum delay to wait for something
       to change. `None` means to wait forever.
    :returns: Revision that is consistent with cache, or None
       if there is no watcher active.
    :raises ValueError: If the delay is negative
    :raises Exception: Whatever the connection reported, once more than
       self._max_reconnects reconnections were needed in one call
    """
    # If there's no timeout, there's a couple cases where we can just keep
    # self._revision without checking with the database server. Having a
    # notification (presumably about some key we are watching) always
    # overrides this.
    if (
        delay_if_no_change == 0
        and not self._cache_needs_refresh
        and self._conn.empty()
        and self._last_progress is not None
    ):
        # Skip refresh if self._revision is "recent enough"
        next_requery = self._last_progress + self._requery_progress
        if time.time() < next_requery:
            return self._revision
    # Current time bound for timeout purposes
    if delay_if_no_change is not None:
        if delay_if_no_change < 0:
            raise ValueError("Timeout should not be negative!")
        end_time = time.time() + delay_if_no_change
    else:
        end_time = None
    # Loop state
    requested_progress = 0
    connection_resets = 0
    # Test hook: when we are allowed to wait, start from "ended by
    # timeout"; it is cleared below as soon as a trigger or a change
    # event is seen.
    if delay_if_no_change is None or delay_if_no_change > 0:
        self._got_timeout = True
    # Process messages
    while True:
        try:
            # Get a response from the queue
            response = self._conn.get_response(end_time)
        # "Empty queue" means that we have processed all outstanding
        # responses and reached the timeout, and therefore want to wrap up
        # the refresh ASAP. Note that this might involve e.g. waiting for a
        # requested progress notification, so "ASAP" might still involve
        # more loops over the above code for receiving responses / waiting.
        except queue.Empty:
            # If there are no watchers, we don't need to request
            # progress, and can simply return immediately
            if (
                not self._conn.get_watchers
                and not self._conn.list_watchers
            ):
                return None
            # Once a progress notification has been sent, we expect the
            # server to answer, so the next timeout is actually an error
            if requested_progress > 0:
                # We have to assume the connection has a problem - push an
                # exception into the input queue, this will flow into the
                # general handling of connection problems below.
                self._conn.put(
                    RuntimeError(
                        "Timeout reached waiting for progress "
                        "notification from server!"
                    )
                )
                continue
            # Otherwise send progress request
            # Future note: Strictly speaking this is not required in all
            # cases - e.g. if we only have one watcher and (!) have a
            # sufficiently recent change event from the server and (!) no
            # cancellations/creations in-flight, we could use the revision
            # from said event. This clearly happens all the time in simple
            # watcher loops that only watch one key - we wake up from a
            # change event, and immediately send a redundant progress
            # notification request. However, let's not introduce a special
            # case until the code is sufficiently stable...
            LOGGER.debug("Requesting progress...")
            self._conn.request_progress()
            requested_progress += 1
            # From now on we wait for the server's answer instead
            end_time = time.time() + self._request_timeout
            continue
        # A value of "None" means that trigger() was called, and
        # we should return ASAP. Simply set the timeout to now.
        if response is None:
            if delay_if_no_change is None or delay_if_no_change > 0:
                self._got_timeout = False
            # Do not shorten the wait for a progress notification that
            # was already requested
            if requested_progress == 0:
                end_time = time.time()
            continue
        # Re-raise exceptions from thread
        if isinstance(response, Exception):
            # If we are out of connection attempts re-raise to the
            # caller
            if connection_resets >= self._max_reconnects:
                raise response
            # Create new connection
            LOGGER.warning("Connection dropped: %s", response)
            self._start_watch_connection()
            # We will need to re-request a progress notification on the new
            # connection if appropriate
            requested_progress = 0
            connection_resets += 1
            continue
        # Global progress notification? We are synchronised!
        if response.watch_id == -1 and not response.events:
            LOGGER.debug(
                "Received progress notification (revision %d)",
                response.header.revision,
            )
            break
        # Process events to populate cache
        if self._process_events(response):
            # Something changed, so return ASAP. If we have already
            # requested a progress notification, we don't need to do
            # anything else.
            if not requested_progress:
                end_time = time.time()
            if delay_if_no_change is None or delay_if_no_change > 0:
                self._got_timeout = False
    # Last response should be a progress notification - only way to escape
    # the loop above without raising an exception. Take over revision from
    # response.
    assert response.watch_id == -1 and not response.events
    self._revision = DbRevision(response.header.revision)
    self._last_progress = time.time()
    self._cache_needs_refresh = False
    return self._revision
def _process_events(self, response: rpc_pb2.WatchResponse):
    """
    Apply all events of a watch response to the caches.

    Put and delete events are handed to :py:meth:`_process_put` and
    :py:meth:`_process_delete`, which also raise the ``_dirty`` flag if
    a value used in the current iteration is affected. Other event types
    are logged and skipped.

    :param response: Watch response received from the server
    :returns: True if at least one put or delete event was processed
    """
    changed = False
    watch_id = response.watch_id
    for event in response.events:
        kind = event.type
        # pylint: disable=no-member
        if kind == Event.EventType.PUT:
            entry = event.kv
            self._process_put(
                watch_id, entry.key, entry.value, entry.mod_revision
            )
        elif kind == Event.EventType.DELETE:
            self._process_delete(watch_id, event.kv.key)
        else:
            LOGGER.warning(
                "Unexpected etcd3 event type, ignored: %s", event.type
            )
            continue
        changed = True
    return changed
def _process_put(self, watch_id, key, value, mod_revision):
    """
    Handle a "put" event: flag the iteration as dirty if it depended on
    the key, then store the new value in the caches.

    :param watch_id: ID of the watcher reporting the event (logging only)
    :param key: Depth-tagged key as bytes; the part before the first
        ``/`` is the depth
    :param value: New value as bytes, stored decoded as UTF-8
    :param mod_revision: Revision of the modification (logging only)
    """
    path = _untag_depth(key)
    separator = key.index(b"/")
    depth = int(key[:separator])
    LOGGER.debug(
        "Watcher %d: put %s (rev %d)", watch_id, path, mod_revision
    )

    # Was the key read by a get earlier in this loop iteration?
    if path in self._get_queries:
        LOGGER.debug(" ... dirty due to get request")
        self._dirty = True

    # Was it covered by a list query of this iteration? Only a key
    # that was not listed before changes the list result.
    for query_depth, query_prefix in self._list_queries:
        if query_depth != depth or not path.startswith(query_prefix):
            continue
        if path in self._list_cache[query_depth][query_prefix]:
            continue
        LOGGER.debug(
            " ... dirty due to list request %s depth %d",
            query_prefix,
            query_depth,
        )
        self._dirty = True
        break

    # Bring value cache and all matching list caches up to date
    self._cache[path] = value.decode("utf-8")
    for prefix, listed in self._list_cache.get(depth, {}).items():
        if path.startswith(prefix) and path not in listed:
            listed.append(path)
def _process_delete(self, watch_id, key):
    """
    Handle a "delete" event: flag the iteration as dirty if it depended
    on the key, then record in the caches that the key is gone.

    Deletions of keys that appear in neither the value cache nor any
    list cache of the key's depth are logged and ignored.

    :param watch_id: ID of the watcher reporting the event (logging only)
    :param key: Depth-tagged key as bytes; the part before the first
        ``/`` is the depth
    """
    path = _untag_depth(key)
    depth = int(key[: key.index(b"/")])
    LOGGER.debug("Watcher %d: delete %s", watch_id, key)

    # Ignore if we never read the key in question. This must be checked
    # per key: testing whether the whole cache is empty (as was done
    # previously) lets deletions of unknown keys through as soon as
    # anything at all has been cached.
    depth_lists = self._list_cache.get(depth, {})
    if path not in self._cache and not any(
        path in list_paths for list_paths in depth_lists.values()
    ):
        LOGGER.warning(
            "Watcher %d: Deletion of unknown key %s!",
            watch_id,
            key,
        )
        return

    # Check whether we depended on this key earlier in
    # this loop iteration
    if path in self._get_queries:
        LOGGER.debug(" ... dirty due to get request")
        self._dirty = True
    for prefix_depth, prefix in self._list_queries:
        if depth == prefix_depth and path.startswith(prefix):
            LOGGER.debug(
                " ... dirty due to list request %s depth %d",
                prefix,
                prefix_depth,
            )
            self._dirty = True
            break

    # Update caches: a cached value of None marks the key as
    # non-existent, and it disappears from all matching list results
    self._cache[path] = None
    for prefix, list_paths in depth_lists.items():
        if path.startswith(prefix) and path in list_paths:
            list_paths.remove(path)
[docs]
def txn(self, max_retries: int = 64) -> Etcd3Transaction:
    """Create nested transaction.

    Every value read through a transaction obtained here is watched:
    the watcher loop runs again once any of them changes in the
    database.

    Apart from that these are ordinary transactions. Each is committed
    as long as it is consistent in itself, so nothing is guaranteed
    *between* transactions of the same watcher - a later transaction
    may observe a newer database state than an earlier one.

    :param max_retries: Maximum number of times the transaction will be
        tried before giving up.
    :raises RuntimeError: If called outside a watcher loop iteration
    """
    # Without a connection there is no watcher loop running
    if self._conn is None:
        raise RuntimeError(
            "Can only be called inside a watcher loop iteration!"
        )

    # The transaction starts from a freshly refreshed revision and will
    # call back into refresh() itself whenever it has to be repeated.
    attempts = Etcd3TransactionWatcher(
        self, self.backend, self._client, self._refresh(0), max_retries
    )
    for txn in attempts:
        wrapper = self._txn_wrapper
        yield txn if wrapper is None else wrapper(txn)
def _watch_get(self, path: str) -> bytes:
    """Return the value at a path and make sure the path is watched.

    The answer comes from the cache where possible; otherwise the
    backend is queried and a "get" watcher is created for the path.

    :param path: Path to watch
    :returns: Key value, None if the key does not exist
    """
    # Transactions refresh before every attempt, so the cache must be
    # consistent whenever we get here.
    assert not self._cache_needs_refresh

    # Record the query - decides later which watchers can be cancelled
    self._get_queries.add(path)

    # Cached? True if a "get" watcher covers exactly this key, or if
    # the key exists and a "list" watcher covers it.
    try:
        return self._cache[path]
    except KeyError:
        pass

    # Covered by a list watcher, but not cached: the key was not part of
    # the list result, hence does not exist. Remember that in the cache
    # so the check is skipped next time, and so a get watcher gets
    # created should the list watcher go away. Checking this also
    # avoids redundant "get" watchers.
    covering_prefixes = self._list_cache.get(depth_of_path(path), {})
    if any(path.startswith(prefix) for prefix in covering_prefixes):
        self._cache[path] = None
        return None

    # Nothing known - ask the backend and cache the answer
    value, value_rev = self.backend.get(path, revision=self._revision)
    self._cache[path] = value
    if self._revision is None:
        self._revision = value_rev

    # Keep the cached value current from this revision onwards
    self._conn.create_get_watcher(path, self._revision)
    return value
def _watch_list(self, prefix: str, depth: int):
    """List keys with a given prefix at a given depth, and start
    watching them.

    Note that this will implicitly *also* get all values of the
    listed keys and cache them.

    :param prefix: Path prefix to list and watch
    :param depth: Depth to query
    :returns: List of keys
    """
    self._list_queries.add((depth, prefix))

    # Are we watching this specific list query already?
    list_cache = self._list_cache.setdefault(depth, {})
    if prefix in list_cache:
        return list_cache[prefix]

    # Perform list request
    path_depth = prefix.count("/")
    paths_and_values, query_rev = self.backend.list_keys(
        prefix,
        revision=self._revision,
        with_values=True,
        recurse=(depth - path_depth,),
    )

    # Populate caches. The third tuple element is deliberately not
    # bound to the same name as the query revision: it used to
    # overwrite it, so that the watcher revision below was seeded from
    # the last entry instead of from the query.
    keys = []
    for path, val, _entry_rev in paths_and_values:
        self._cache[path] = val
        keys.append(path)
    if self._revision is None:
        self._revision = query_rev
    list_cache[prefix] = keys

    # Create watcher
    self._conn.create_list_watcher(prefix, depth, self._revision)

    return keys
[docs]
def trigger(self):
    """Manually triggers a loop

    Safe to call from another thread: a ``None`` marker is pushed into
    the connection's input queue, which makes a waiting watcher wrap
    up its wait and iterate.

    :raises RuntimeError: If no watcher loop is running
    """
    conn = self._conn
    if conn is None:
        raise RuntimeError("No watcher loop running!")
    conn.put(None)
class _Etcd3WatcherConnection:
"""
Internal class for tracking etcd3 watcher connection state
Tracks active watchers
"""
def __init__(self, channel):
    """
    Open a watch stream on the given channel.

    Requests are sent by putting them into an output queue; responses
    (or the exception that ended the stream) are moved into an input
    queue by a background thread.

    :param channel: gRPC channel to use
    """
    stub = rpc_pb2_grpc.WatchStub(channel)

    # Outgoing direction: the stream consumes the queue until it
    # encounters a None entry
    self._out_queue = queue.Queue()
    request_iter = iter(self._out_queue.get, None)
    self._responses = stub.Watch(request_iter)

    # Incoming direction: a thread copies everything into a queue
    self._in_queue = queue.Queue()

    def _in_thread():
        try:
            for response in self._responses:
                self._in_queue.put(response)
        # pylint: disable=broad-exception-caught
        except Exception as exc:
            # Hand stream failures to the consumer instead of losing
            # them with the thread
            self._in_queue.put(exc)

    self._in_thread = threading.Thread(target=_in_thread)
    self._in_thread.start()

    # Watcher bookkeeping
    self._next_watch_id = 0
    self._created_watchers = set()
    self._cancelled_watchers = set()  # watch IDs
    self.get_watchers = {}  # path -> watch ID
    self.list_watchers = {}  # (depth, path) -> watch ID
    self._synchronising = False  # progress notification on the way?
def close(self):
    """
    Closes the connection

    The connection object is unusable after this. Blocks until the
    thread receiving responses has finished.
    """
    # None is the sentinel that ends the request iterator
    self._out_queue.put(None)
    # Cancelling the call ends the response iteration in the thread
    self._responses.cancel()
    self._in_thread.join()
def create_get_watcher(self, path: str, revision: DbRevision):
    """Request a watcher for exactly one key.

    Must not be called for a path that already has a get watcher, nor
    while a progress request is outstanding.

    :param path: Path to watch
    :param revision: Database revision at which to start watching;
        events are requested from the following revision onwards
    :returns: watcher ID
    """
    assert path not in self.get_watchers
    assert not self._synchronising

    # The gRPC definitions packaged with etcd3-python cannot pass a
    # watch_id yet (later protocol versions can), so we count along
    # locally and assume the result matches.
    watch_id = self._next_watch_id
    self._next_watch_id = watch_id + 1

    key = _tag_depth(path)
    start_from = revision.revision
    LOGGER.debug(
        "Start get watcher %d for %s at rev %d...",
        watch_id,
        key.decode("utf-8"),
        start_from,
    )

    # Range [key, key + "\0") contains just the key itself
    creation = rpc_pb2.WatchCreateRequest(
        key=key,
        range_end=key + b"\0",
        start_revision=start_from + 1,
    )
    self._out_queue.put(rpc_pb2.WatchRequest(create_request=creation))

    # Register the watcher and expect its creation to be confirmed
    self.get_watchers[path] = watch_id
    self._created_watchers.add(watch_id)
    return watch_id
def create_list_watcher(
    self, prefix: str, depth: int, revision: DbRevision
):
    """Creates a new watcher watching all keys with a prefix at a given
    depth

    Must not be called for a (depth, prefix) pair that is already
    watched, nor while a progress request is outstanding.

    :param prefix: Path prefix to watch
    :param depth: Restrict to paths at given depth
    :param revision: Database revision at which to start watching;
        events are requested from the following revision onwards
    :returns: watcher ID
    """
    assert (depth, prefix) not in self.list_watchers
    assert not self._synchronising

    # Allocate a fresh watcher ID
    key = _tag_depth(prefix, depth)
    watch_id = self._next_watch_id
    self._next_watch_id += 1
    LOGGER.debug(
        "Start list watcher %d for %s (depth %d) at rev %d...",
        watch_id,
        key.decode("utf-8"),
        depth,
        revision.revision,
    )

    # Send watch creation message via queue
    # gRPC packaged with etcd3-python doesn't support watch_id yet - it
    # appears in later protocol versions. So we just need to assume that we
    # can predict the watcher ID...
    # The range ends at the key with its last byte incremented, i.e.
    # it spans every key that starts with the tagged prefix.
    end_key = key[:-1] + bytes([key[-1] + 1])
    self._out_queue.put(
        rpc_pb2.WatchRequest(
            create_request=rpc_pb2.WatchCreateRequest(
                key=key,
                range_end=end_key,
                start_revision=revision.revision + 1,
                # watch_id=watch_id,
            )
        )
    )

    # Register watcher. We assume its value was used.
    self.list_watchers[(depth, prefix)] = watch_id
    self._created_watchers.add(watch_id)

    # Return the ID as documented, matching create_get_watcher()
    return watch_id
def put(self, obj):
    """
    Put object at back of input queue such that get_response()
    eventually returns it.

    Anything that is not a watch response (e.g. None or an exception)
    is handed out by get_response() unchanged. As this only uses the
    queue, it also works from other threads.

    :param obj: Object to enqueue
    """
    self._in_queue.put(obj)
def request_progress(self):
    """
    Ask the server for a progress notification.

    Until the notification has been received through get_response()
    the connection counts as "synchronising", during which no watchers
    may be created.
    """
    progress = rpc_pb2.WatchProgressRequest()
    self._out_queue.put(rpc_pb2.WatchRequest(progress_request=progress))
    self._synchronising = True
def empty(self):
    """
    Checks whether the input queue holds no responses at the moment

    If True, it is likely (!) that get_response would
    not return any response right now.

    :returns: True if nothing is waiting in the input queue
    """
    return self._in_queue.empty()
# pylint: disable=too-many-branches
def get_response(self, end_time: Optional[float] = None):
    """
    Get next response from connection

    Creation and cancellation confirmations, as well as responses of
    watchers that are being cancelled, are consumed internally and
    never returned.

    :param end_time: Time at which to return. None waits indefinitely.
    :returns: Either a Response object from the connection, an
       exception raised while receiving the message, or any other
       object placed into the queue using put() (such as None)
    :raises queue.Empty: If timeout was reached
    :raises RuntimeError: If a watcher hit a compacted revision, or
       was created or cancelled without us having asked for it
    """
    while True:
        # Wait for designated time, if an end time is given
        if end_time is None:
            response = self._in_queue.get()
        else:
            # An end time in the past means: only take what is there
            timeout = max(0, end_time - time.time())
            # Might raise queue.Empty to caller
            response = self._in_queue.get(timeout=timeout)
        # Non-responses (exceptions etc) get returned as-is.
        if not isinstance(response, rpc_pb2.WatchResponse):
            return response
        # Trying to start watcher at compacted revision?
        if response.compact_revision:
            # This is a fairly specialised case - this will only happen if
            # we create a new watcher on a revision that has in the
            # meantime been compacted by the server. This effectively means
            # that our current "revision" is so far out of date that the
            # server cannot guarantee any more that they can send us all
            # events from that revision. This should be an exceedingly rare
            # condition (and generally fixable simply by setting a
            # sufficiently low timeout on the watcher). However, we do need
            # to catch it, because we might otherwise lose data.
            raise RuntimeError(
                f"Watcher {response.watch_id} could not be created - "
                f"compacted from revision {response.compact_revision}!"
            )
        # Creation?
        if response.created:
            if response.watch_id not in self._created_watchers:
                raise RuntimeError(
                    f"Watcher {response.watch_id} created unexpectedly!"
                )
            LOGGER.debug("Watcher %d created", response.watch_id)
            self._created_watchers.remove(response.watch_id)
            continue
        # Cancellation?
        if response.canceled:
            if response.watch_id not in self._cancelled_watchers:
                raise RuntimeError(
                    f"Watcher {response.watch_id} cancelled unexpectedly "
                    f"({response.cancel_reason})"
                )
            LOGGER.debug(
                "Watcher %d cancelled (%s)",
                response.watch_id,
                response.cancel_reason,
            )
            self._cancelled_watchers.remove(response.watch_id)
            continue
        # Ignore responses from watchers we are currently cancelling
        if response.watch_id in self._cancelled_watchers:
            continue
        # A progress notification should especially have flushed out
        # all watcher creation and cancellations
        if response.watch_id == -1 and not response.events:
            # Synchronisation is over - watchers may be created again
            self._synchronising = False
            if self._created_watchers:
                LOGGER.warning(
                    "Missing notifications for created watchers: %s",
                    self._created_watchers,
                )
            if self._cancelled_watchers:
                LOGGER.warning(
                    "Missing notifications for cancelled watchers: %s",
                    self._cancelled_watchers,
                )
        # Otherwise return the response
        return response
# pylint: disable=too-many-arguments
def close_unused(
    self,
    get_queries: set[str],
    list_queries: set[tuple[int, str]],
    get_cache: dict[str, str],
    list_cache: dict[int, dict[str, list[str]]],
    revision: DbRevision,
):
    """
    Close any watchers not covered by the given queries, and
    removes uncovered cache entries

    This might actually involve creating new watchers, e.g. if
    we close a list watcher, but still need to cover get queries.

    Note that ``get_queries``, ``get_cache`` and ``list_cache`` are
    modified in place.

    :param get_queries: Set of paths
    :param list_queries: Set of (depth, prefix) pairs
    :param get_cache: Cache of get results
    :param list_cache: Cache of list results
    :param revision: Revision to start any new watchers at
    """
    # Collect watchers to cancel
    cancel_watchers = set()
    # Paths whose cache entry must not be deleted in the get watcher
    # pass below - either because a list watcher keeps it up to date,
    # or because it has been deleted already
    protect_caches = set()

    # First check all list watchers for usage
    for (depth, prefix), watch_id in list(self.list_watchers.items()):
        # Find get cache entries covered by list
        covered_keys = [
            path
            for path in get_cache
            if path.startswith(prefix) and depth_of_path(path) == depth
        ]

        # Queried the list?
        if (depth, prefix) in list_queries:
            # Clear all get queries for keys covered by the
            # list watcher so that we won't keep "get"
            # watchers alive for them below
            get_overlap = get_queries & set(covered_keys)
            get_queries -= get_overlap
            protect_caches |= get_overlap
            continue

        LOGGER.debug(
            "Cancelling list watcher %d (%s, depth %d)",
            watch_id,
            prefix,
            depth,
        )

        # Cancel list watch
        cancel_watchers.add(watch_id)
        del self.list_watchers[(depth, prefix)]
        del list_cache[depth][prefix]

        # Clear all cache entries with this prefix - unless they were
        # read this iteration, in which case they get a replacement
        # "get" watcher instead.
        # Future note: Make this more efficient - using an ordered
        # dictionary for _cache would allow us to skip iterating over the
        # whole cache here...
        for path in covered_keys:
            if path not in get_queries:
                del get_cache[path]
                protect_caches.add(path)
            elif path not in self.get_watchers:
                # Create replacement "get" watcher for the path
                self.create_get_watcher(path, revision)

    # Next check all get watchers - this is a simpler check: Every
    # get watcher should have a corresponding get query that
    # wasn't covered by a list watcher already.
    for path, watch_id in list(self.get_watchers.items()):
        if path not in get_queries:
            LOGGER.debug("Cancelling get watcher %d (%s)", watch_id, path)
            if path not in protect_caches:
                del get_cache[path]
            del self.get_watchers[path]
            cancel_watchers.add(watch_id)

    # Send cancel requests
    for watch_id in cancel_watchers:
        self._out_queue.put(
            rpc_pb2.WatchRequest(
                cancel_request=rpc_pb2.WatchCancelRequest(
                    watch_id=watch_id
                )
            )
        )

    # Save back that we cancelled them, so that get_response() accepts
    # the confirmations and drops late events from these watchers
    self._cancelled_watchers = self._cancelled_watchers | cancel_watchers
# pylint: disable=protected-access
[docs]
class Etcd3TransactionWatcher(Etcd3Transaction):
"""A series of queries and updates to be executed atomically.
This transaction offers the same interface as Etcd3Transaction,
but utilises the watcher's cache to return results more
efficiently.
"""
# pylint: disable=too-many-arguments
def __init__(
    self,
    watcher: Etcd3Watcher,
    backend: Etcd3Backend,
    client: etcd3.Etcd3Client,
    revision: DbRevision,
    max_retries: int = 64,
):
    """Initialise transaction.

    :param watcher: Watcher whose cache and watchers serve the queries
    :param backend: Backend, passed on to the base transaction
    :param client: etcd3 client, passed on to the base transaction
    :param revision: Revision the transaction starts at. May be None.
    :param max_retries: Maximum number of attempts, passed on to the
        base transaction
    """
    # Initialise
    # NOTE(review): the watcher is stored before the base initialiser
    # runs - presumably because that may already call methods
    # overridden in this class; confirm before reordering.
    self._watcher = watcher
    super().__init__(backend, client, max_retries)
    self._revision = revision
[docs]
def get(self, path: str) -> str:
    """
    Get value of a key.

    Values written earlier in this transaction take precedence;
    everything else is answered by the watcher, which starts watching
    the path.

    :param path: Path of key to query
    :returns: Key value. None if it doesn't exist.
    """
    self._ensure_uncommitted()

    # Written within this transaction? Then that is the value.
    pending = self._updates
    if path in pending:
        return pending[path][0]

    # Otherwise ask the watcher, which must be at our revision
    watcher = self._watcher
    assert watcher._revision == self._revision
    value = watcher._watch_get(path)
    if self._revision is None:
        self._revision = watcher._revision

    # Record value and revision so that commit() can verify them
    self._get_queries[path] = (value, self._revision)
    return value
# pylint: disable=duplicate-code
[docs]
def list_keys(self, path: str, recurse: int = 0):
    """
    List keys under given path.

    :param path: Prefix of keys to query. Append '/' to list
       child paths.
    :param recurse: Children depths to include in search. Either an
       iterable of relative depths, or a number ``n`` standing for the
       depths ``0`` to ``n``.
    :returns: sorted key list
    """
    self._ensure_uncommitted()
    base_depth = path.count("/")

    # Which relative depths are we asked for?
    try:
        rel_depths = iter(recurse)
    except TypeError:
        rel_depths = range(recurse + 1)

    collected = []
    for rel_depth in rel_depths:
        abs_depth = base_depth + rel_depth

        # Uncommitted writes of this transaction might add keys to, or
        # remove keys from, the range in question
        tagged_prefix = _tag_depth(path, abs_depth)
        created, deleted = set(), set()
        for key, update in self._updates.items():
            if _tag_depth(key).startswith(tagged_prefix):
                if update[0] is None:
                    deleted.add(key)
                else:
                    created.add(key)

        # Reuse the result if this transaction did the query before,
        # otherwise go through the watcher
        query = (path, abs_depth)
        if query in self._list_queries:
            listed = self._list_queries[query][0]
        else:
            listed = self._watcher._watch_list(path, abs_depth)
            if self._revision is None:
                self._revision = self._watcher._revision
            self._list_queries[query] = (listed, self._revision)

        collected.extend(set(listed) - deleted | created)

    return sorted(collected)
[docs]
def commit(self) -> bool:
    """
    Commit the transaction to the database.

    This can fail, in which case the transaction must get `reset`
    and built again.

    :returns: Whether the commit succeeded
    """
    if super().commit():
        # If anything was written the database has a new revision.
        # Knowing the values we wrote does not help: other values
        # might have changed as well, and the only efficient way to
        # find out is to synchronise all watchers again. So just mark
        # the watcher cache as in need of a refresh.
        if self._updates:
            self._watcher._cache_needs_refresh = True
            LOGGER.debug("Revision changed - need refresh")
        return True

    # Failed: the next attempt needs a new revision, which takes a
    # refresh of the watcher
    LOGGER.debug("Commit failed")
    self._revision = self._watcher._refresh(0)
    return False
[docs]
def reset(self, revision: Optional[DbRevision] = None) -> None:
    """
    Reset the transaction, so it can be restarted after commit()

    :param revision: If given, must equal the transaction's current
        revision.
    :raises ValueError: If a different revision is requested
    """
    current = self._revision
    if revision is not None and revision != current:
        raise ValueError("Watcher transaction cannot override revision!")

    # The revision is kept - it has to stay in step with the watcher
    assert current == self._watcher._revision
    super().reset(current)