Source code for ska_sdp_dal.connection

import abc
import logging
import select
import time
import weakref
from typing import Dict, List, Mapping, Tuple, Union

import awkward
import numpy
import pyarrow
import pyarrow.plasma as plasma

from . import common

logger = logging.getLogger(__name__)


class Connection(object):
    """A connection to a Plasma store.

    Subscribes to store events and uses them to maintain a list of
    objects in the store. We especially track namespaces.

    (See the usage sketch following this class.)
    """

    def __init__(self, plasma_path: str):
        # Connect to Plasma
        self._client = plasma.connect(plasma_path)

        # Initialise buffer cache
        self._buf_cache = weakref.WeakValueDictionary()

        # Subscribe to event updates
        self._client.subscribe()
        self._socket = self._client.get_notification_socket()

        # Get list of objects
        self._obj_table = self._client.list()
        self._ns_table_meta = {}
        self._ns_table_procs = {}
        for oid in self._obj_table:
            if common.is_namespace_decl(oid):
                self._parse_namespace(oid)

    @property
    def client(self) -> plasma.PlasmaClient:
        return self._client

    @property
    def namespaces(self) -> List[plasma.ObjectID]:
        return list(self._ns_table_meta)

    @property
    def namespace_meta(self) -> Dict[plasma.ObjectID, Dict[bytes, bytes]]:
        return self._ns_table_meta

    @property
    def namespace_procs(self) -> Dict[plasma.ObjectID, List[pyarrow.Schema]]:
        return self._ns_table_procs

    def object_exists(self, oid) -> bool:
        """Checks whether the given object ID is known to exist

        :param oid: Object ID to check
        """
        return oid in self._obj_table

    def object_size(self, oid) -> Union[int, None]:
        """Gets the size of the given object

        :param oid: Object ID to check
        :returns: Data size in bytes, or ``None`` if the object is unknown
        """
        obj_record = self._obj_table.get(oid)
        if obj_record is None:
            return None
        return obj_record["data_size"]

    def get_buffers(self, oids, timeout=None) -> List[plasma.PlasmaBuffer]:
        """Retrieve buffers for the given OIDs.

        Uses a cache to prevent duplicated requests to the Plasma store.

        :param oids: Plasma object IDs
        :param timeout: Time to wait for buffers to become available
        :returns: List of Plasma buffers, with ``None`` entries for
            unavailable objects
        """
        # Check what OIDs to request (only those missing from the cache)
        oids_to_request = [oid for oid in oids if oid not in self._buf_cache]
        if oids_to_request:
            bufs = self._client.get_buffers(
                oids_to_request, -1 if timeout is None else timeout
            )
            # Add to cache
            for oid, buf in zip(oids_to_request, bufs):
                if buf is not None:
                    self._buf_cache[oid] = buf
        # Return all buffers
        return [self._buf_cache.get(oid) for oid in oids]

    def update_obj_table(self, timeout: float = 0):
        """Update the object table

        :param timeout: If given, allow blocking for up to the given time
            or until the next update happens.
        :returns: A list of received update notifications
        """
        notifications = []
        while True:
            # Getting the next notification would block, so check the
            # socket with select first - this way we can implement a
            # timeout.
            rs, _, _ = select.select([self._socket.fileno()], [], [], timeout)
            if len(rs) == 0:
                break
            notification = self._client.get_next_notification()
            notifications.append(notification)
            oid, data_size, metadata_size = notification

            # Deleted?
            if data_size < 0 or metadata_size < 0:
                if oid not in self._obj_table:
                    continue
                del self._obj_table[oid]
                if common.is_namespace_decl(oid):
                    if oid in self._ns_table_meta:
                        del self._ns_table_meta[oid]
                    if oid in self._ns_table_procs:
                        del self._ns_table_procs[oid]
                continue

            # Set in object table
            self._obj_table[oid] = {
                "data_size": data_size,
                "metadata_size": metadata_size,
            }

            # Namespace?
            if common.is_namespace_decl(oid):
                try:
                    self._parse_namespace(oid)
                except Exception as e:
                    # Ignore anything going wrong here - we don't want
                    # this to break other code
                    logger.warning("Failed to parse namespace: %s", e)

            # Reset timeout (process any remaining notifications
            # without blocking)
            timeout = 0
        return notifications

    def _parse_namespace(self, oid: plasma.ObjectID):
        # Get buffer, copy. We do not want to leave behind trailing
        # references to namespace objects.
        buf_raw = self.get_buffers([oid], 0)[0]
        if buf_raw is None:
            # Vanished before we got to it - simply ignore
            return
        buf = pyarrow.py_buffer(buf_raw)
        procs = {}
        meta = {}
        if buf.size > 0:
            # Parse
            reader = pyarrow.ipc.open_stream(pyarrow.BufferReader(buf))
            if reader.schema.metadata is not None:
                meta = reader.schema.metadata

            # Read supported calls
            name_col = reader.schema.get_field_index("name")
            schema_col = reader.schema.get_field_index("schema")
            for batch in reader:
                iters = [iter(col) for col in batch.columns]
                for _ in range(batch.num_rows):
                    name = next(iters[name_col]).as_py()
                    schema_buf = next(iters[schema_col]).as_buffer()

                    # Attempt to read the schema. We can use a record
                    # batch reader for this - a serialised schema looks
                    # just like a record batch stream without any
                    # actual record batches.
                    call_schema = pyarrow.ipc.open_stream(
                        pyarrow.BufferReader(schema_buf)
                    ).schema
                    procs[name] = call_schema

        # Save
        self._ns_table_meta[oid] = meta
        self._ns_table_procs[oid] = procs

    def reserve_namespace(
        self,
        name: str = None,
        procs: List[pyarrow.Schema] = [],
        prefix: bytes = b"",
    ) -> Tuple[bytes, plasma.PlasmaBuffer]:
        """Reserve a new namespace within the Plasma store

        This will automatically clear all objects with the newly
        reserved prefix.

        :param name: Informative display name for namespace
        :param procs: Call schemas supported (if any)
        :param prefix: Bytes the generated namespace prefix must
            start with
        :returns: Prefix, buffer with declaration (to keep namespace
            alive)
        """
        # Make set of used prefixes
        self.update_obj_table()
        used_prefixes = {
            oid.binary()[: common.NAMESPACE_ID_SIZE]
            for oid in self._obj_table.keys()
        }

        # Serialise namespace declaration
        decl = common._make_namespace_decl(name, procs)
        stream = pyarrow.BufferOutputStream()
        writer = pyarrow.ipc.RecordBatchStreamWriter(stream, decl.schema)
        writer.write(decl)
        writer.close()
        data = stream.getvalue().to_pybytes()

        # Reserve a free namespace
        for _prefix in common.objectid_generator(
            prefix, common.NAMESPACE_ID_SIZE
        ):
            if _prefix in used_prefixes:
                continue
            try:
                # Attempt to write into Plasma
                oid = plasma.ObjectID(
                    next(common.objectid_generator(_prefix))
                )
                self._client.create_and_seal(oid, data, b"")

                # Immediately delete without releasing (-> normally
                # will get deleted once client connection closes)
                root = self.get_buffers([oid])[0]
                self._client.delete([oid])

                # Request deletion of all objects in new namespace
                for old_oid in self._obj_table:
                    if old_oid.binary().startswith(_prefix):
                        self._client.delete([old_oid])

                # Found our prefix!
                return _prefix, root
            except plasma.PlasmaObjectExists:
                continue

        # Nothing found!
        raise RuntimeError("Could not find a free Plasma namespace!")

    def get_ref_buffers(
        self,
        refs: List["Ref"],
        timeout: float = None,
        auto_delete: bool = True,
    ) -> None:
        """Retrieves the buffers for multiple Plasma references at a time.

        Blocks as long as any (!) of the objects have not been created
        yet.

        :param refs: References to retrieve buffers of
        :param timeout: Maximum time this function is allowed to block
        :param auto_delete: Delete objects in store when references are
            dropped
        :raises: TimeoutException
        """
        # Collect references not resolved yet
        unresolved = [ref for ref in refs if ref._buf is None]
        if len(unresolved) == 0:
            return

        # Wait for objects to become available
        waiting = list(unresolved)
        start = time.time()
        while waiting:
            # Remove available objects from waiting list
            still_waiting = [
                ref for ref in waiting if ref._oid not in self._obj_table
            ]
            if len(still_waiting) == 0:
                break
            waiting = still_waiting

            # Wait for an update on the object table
            if timeout is not None:
                dt = start + timeout - time.time()
                if dt <= 0:
                    raise TimeoutException(still_waiting, timeout)
                self.update_obj_table(dt)
            else:
                # Block until the next notification arrives
                self.update_obj_table(None)

        # Batch-request objects and make them temporary
        oids = [ref._oid for ref in unresolved]
        bufs = self.get_buffers(oids, 0)
        if auto_delete:
            # As we are holding buffer references, this will only
            # actually happen once no buffer reference exists any more
            self._client.delete(oids)
        for buf, ref in zip(bufs, unresolved):
            ref._buf = buf

        # Final check: it is possible that an object vanished in the
        # meantime. We only do this after the above loop to make sure
        # we retain as many references as possible.
        for ref in unresolved:
            if ref._buf is None:
                raise RuntimeError(
                    f"Object {common.object_id_hex(ref._oid)} vanished "
                    "before it could be read!"
                )
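
# Illustrative usage sketch, not part of the original module: connecting to
# a Plasma store and polling for updates. The socket path "/tmp/plasma" and
# the all-zero object ID are placeholder assumptions.
def _example_connection_usage():  # pragma: no cover
    conn = Connection("/tmp/plasma")  # hypothetical store socket path

    # Collect creation/deletion notifications for up to a second
    notifications = conn.update_obj_table(timeout=1.0)
    logger.info("Received %d notifications", len(notifications))

    # Query a (hypothetical) object ID
    oid = plasma.ObjectID(b"\x00" * 20)
    if conn.object_exists(oid):
        logger.info("Object size: %s bytes", conn.object_size(oid))

    # Enumerate namespaces and the calls they support
    for ns_oid in conn.namespaces:
        logger.info(
            "Namespace %s supports calls: %s",
            common.object_id_hex(ns_oid),
            list(conn.namespace_procs[ns_oid]),
        )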


class Ref(metaclass=abc.ABCMeta):
    """Refers to an object in storage

    Subclassed by type. The object might not have been created yet.
    Can have two kinds of relationships with other objects:

    * dependency: Object is required for this object to be created.
      Must ensure objects stay alive *until* this object is found to
      be created. Typically refers to call objects.
    * reference: Object that is referenced from this object and must
      therefore be kept alive while this object is still needed.

    (See the usage sketch following this class.)
    """

    def __init__(
        self,
        conn: Connection,
        oid: plasma.ObjectID,
        auto_delete: bool = True,
        dependencies: List["Ref"] = [],
        references: List["Ref"] = [],
    ):
        self._conn = conn
        self._oid = oid
        self._auto_delete = auto_delete
        self._buf = None
        self._dependencies = []
        self._references = []
        for dep in dependencies:
            self.add_dependency(dep)
        for ref in references:
            self.add_reference(ref)

    @property
    def oid(self) -> plasma.ObjectID:
        return self._oid

    def add_dependency(self, ref: "Ref", timeout: float = 0) -> None:
        """Registers the identified object as a dependency.

        This will ensure that the object is kept alive *until* we have
        retrieved the data for this object. Blocks if the object does
        not yet exist in the store.

        :param ref: Reference to register as dependency
        :param timeout: Maximum time this method is allowed to block.
        :raises: TimeoutException
        """
        # Construct reference, if needed
        if not isinstance(ref, Ref):
            ref = Ref(self._conn, ref)
        # Ensure we hold the dependency's buffer (might raise a
        # timeout exception)
        self._conn.get_ref_buffers([ref], timeout, ref._auto_delete)
        # Add dependency
        self._dependencies.append(ref)

    def add_reference(
        self, ref: Union["Ref", plasma.ObjectID], timeout: float = 0
    ) -> None:
        """Registers the identified object as referenced

        This will ensure that the object is kept alive as long as this
        object is referenced. Might block if the object does not yet
        exist in the Plasma store.

        :param ref: Reference to register as referenced
        :param timeout: Maximum time this method is allowed to block.
        :raises: TimeoutException
        """
        # Construct reference, if needed
        if not isinstance(ref, Ref):
            ref = Ref(self._conn, ref)
        # Ensure we hold the referenced object's buffer (might raise a
        # timeout exception)
        self._conn.get_ref_buffers([ref], timeout, ref._auto_delete)
        # Add reference
        self._references.append(ref)

    def get_buffer(
        self, timeout: float = None, auto_delete: bool = True
    ) -> pyarrow.Buffer:
        """Get Arrow buffer for this Plasma reference.

        Blocks if the object has not yet been created.

        :param timeout: Maximum time this method is allowed to block.
        :param auto_delete: Delete object in store when reference is
            dropped
        :raises: TimeoutException
        """
        # Return buffer, if present
        if self._buf is not None:
            return self._buf
        # Retrieve buffer
        self._conn.get_ref_buffers([self], timeout, self._auto_delete)
        # Clear dependencies
        self._dependencies = []
        return self._buf

    def __str__(self):
        return f"{self.__class__.__name__}({common.object_id_hex(self._oid)})"
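
# Illustrative sketch, not part of the original module: declaring a
# dependency and resolving a reference. The object IDs are placeholders,
# and the call object is assumed to already exist in the store.
def _example_ref_usage(conn: Connection):  # pragma: no cover
    result = Ref(conn, plasma.ObjectID(b"\x01" * 20))

    # Keep a (hypothetical) call object alive until the result appears
    result.add_dependency(plasma.ObjectID(b"\x02" * 20))

    # Block for up to ten seconds until the result object exists, then
    # read its raw Arrow buffer; dependencies are released afterwards
    buf = result.get_buffer(timeout=10.0)
    logger.info("Got %d bytes for %s", buf.size, result)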


class TimeoutException(Exception):
    """Raised when objects do not appear in the Plasma store within the
    allowed time. (See the handling sketch following this class.)
    """

    def __init__(self, refs, timeout):
        self._refs = refs
        self._timeout = timeout

    @property
    def refs(self):
        return self._refs

    @property
    def timeout(self):
        return self._timeout

    def __str__(self):
        oids = [common.object_id_hex(ref.oid) for ref in self.refs]
        if self._timeout > 0:
            return (
                f"Object{'s' if len(oids) > 1 else ''} {','.join(oids)} did "
                f"not appear in Plasma store within {self._timeout}s!"
            )
        else:
            return (
                f"Object{'s' if len(oids) > 1 else ''} {','.join(oids)} "
                f"{'do' if len(oids) > 1 else 'does'} not exist in "
                "Plasma store!"
            )
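
# Illustrative sketch, not part of the original module: handling the
# timeout raised when an object never appears. The object ID is a
# placeholder.
def _example_timeout_handling(conn: Connection):  # pragma: no cover
    ref = Ref(conn, plasma.ObjectID(b"\x03" * 20))
    try:
        conn.get_ref_buffers([ref], timeout=0.5)
    except TimeoutException as e:
        # e.refs lists the references still outstanding; e.timeout is
        # the timeout that was exceeded
        logger.warning(
            "Still waiting for %s after %ss",
            [str(r) for r in e.refs],
            e.timeout,
        )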


class TensorRef(Ref):
    """Refers to a tensor in object storage.

    Might not have been created yet - wraps an Object ID and expected
    type information. (See the usage sketch following this class.)
    """

    def __init__(
        self,
        conn: Connection,
        oid: plasma.ObjectID,
        typ: pyarrow.DataType = None,
        dim_names: List[str] = None,
        auto_delete: bool = True,
    ):
        self._typ = typ
        self._dim_names = dim_names
        super(TensorRef, self).__init__(conn, oid, auto_delete)

    @property
    def typ(self) -> pyarrow.DataType:
        return self._typ

    @property
    def dim_names(self) -> List[str]:
        return self._dim_names

    def put(self, arr: numpy.ndarray = None):
        """Write the given value into storage.

        :param arr: Array to write to storage. Empty by default.
        """
        # If no array given, write an empty numpy array
        if arr is None:
            arr = numpy.empty(tuple([0] * len(self._dim_names)))
        self._buf = common._put_numpy(
            self._conn.client,
            self._oid,
            arr,
            self._typ,
            self._dim_names,
            self._auto_delete,
        )

    def get(self, timeout=None):
        """Retrieve the tensor from storage. Might block.

        :param timeout: Maximum time this method is allowed to block.
        """
        # Read
        return common._tensor_from_buf(
            self.get_buffer(timeout),
            common.object_id_hex(self._oid),
            self._typ,
            self._dim_names,
        )
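
# Illustrative sketch, not part of the original module: writing and
# reading a tensor through a TensorRef. The object ID, element type and
# dimension names are placeholder assumptions.
def _example_tensor_usage(conn: Connection):  # pragma: no cover
    ref = TensorRef(
        conn,
        plasma.ObjectID(b"\x04" * 20),
        typ=pyarrow.float64(),
        dim_names=["time", "channel"],
    )

    # Write a 16x4 tensor into the store ...
    ref.put(numpy.zeros((16, 4)))

    # ... and read it back (would block if it had not been created yet)
    return ref.get(timeout=1.0)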


class TableRef(Ref):
    """Refers to a table in the Plasma store.

    Might not have been created yet - wraps an Object ID and expected
    type information. (See the usage sketch at the end of this module.)
    """

    def __init__(
        self,
        conn: Connection,
        oid: plasma.ObjectID,
        schema: pyarrow.Schema = None,
        auto_delete: bool = True,
    ):
        self._schema = schema
        self._table = None
        super(TableRef, self).__init__(conn, oid, auto_delete)

    @property
    def schema(self) -> pyarrow.Schema:
        return self._schema

    def put(
        self,
        table: Union[
            pyarrow.Table,
            Mapping[str, pyarrow.ChunkedArray],
            Mapping[str, pyarrow.Array],
            Mapping[str, list],
        ],
        max_chunksize=None,
    ):
        """Write the given table into storage.

        The table can be given as a pandas :py:class:`DataFrame` or as
        a dictionary mapping strings to :py:class:`pyarrow.Array`
        objects or lists, which will be converted into the equivalent
        table. If the arrays are chunked, the chunks of all columns
        must match. See also :py:meth:`pyarrow.table`.

        :param table: Table to write
        :param max_chunksize: Maximum size of record batches to split
            the table into
        """
        # A table?
        if isinstance(table, pyarrow.Table):
            # Check schema compatibility
            if self._schema is not None:
                problems = common.schema_compatible(
                    self._schema, table.schema
                )
                if problems:
                    raise TypeError(
                        "Wrong table schema: " + ", ".join(problems)
                    )
        else:
            # Convert into table
            table = pyarrow.table(table, schema=self._schema)

        # Get record batches, determine size sum
        batches = table.to_batches(max_chunksize)
        record_batch_size = sum(
            map(pyarrow.ipc.get_record_batch_size, batches)
        )

        # Determine base size required for headers (unfortunately
        # there doesn't seem to be a more direct way). Cache?
        test_stream = pyarrow.BufferOutputStream()
        pyarrow.ipc.RecordBatchStreamWriter(test_stream, self._schema).close()
        table_size = test_stream.tell() + record_batch_size

        # Reserve in Plasma
        self._buf = self._conn.client.create(self._oid, table_size)

        # Write out batches
        buf_writer = pyarrow.FixedSizeBufferWriter(self._buf)
        writer = pyarrow.ipc.RecordBatchStreamWriter(buf_writer, self._schema)
        for batch in batches:
            writer.write_batch(batch)
        writer.close()

        # Check that the prediction was right - if this breaks we want
        # to learn about it early
        if buf_writer.tell() != table_size:
            logger.warning(
                "Unexpected size of serialised table: "
                f"{buf_writer.tell()} != {table_size}!"
            )

        # Seal + request deletion
        self._conn.client.seal(self._oid)
        self._conn.client.delete([self._oid])

    def get(self, timeout=None):
        """Get Arrow table

        :param timeout: How long the function is allowed to block if
            the object does not exist yet
        """
        # Table already read?
        if self._table is not None:
            return self._table
        # Read table from buffer
        buf_reader = pyarrow.BufferReader(self.get_buffer(timeout))
        reader = pyarrow.ipc.RecordBatchStreamReader(buf_reader)
        self._table = reader.read_all()
        return self._table

    def get_awkward(self):
        """Get table as Awkward array

        This can generally be done without copying the data.
        """
        return awkward.from_arrow(self.get())

    def get_pandas(self, *args, **kwargs):
        """Get table as Pandas dataframe

        :param kwargs: Parameters for the Pandas conversion. See
            :py:meth:`pyarrow.Table.to_pandas`.
        """
        return self.get().to_pandas(*args, **kwargs)

    def get_dict(self):
        """Get table as Python dictionary"""
        return self.get().to_pydict()
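
# Illustrative sketch, not part of the original module: writing and
# reading a table through a TableRef. The schema, column names and
# object ID are placeholder assumptions.
def _example_table_usage(conn: Connection):  # pragma: no cover
    schema = pyarrow.schema(
        [("antenna", pyarrow.int32()), ("gain", pyarrow.float64())]
    )
    ref = TableRef(conn, plasma.ObjectID(b"\x05" * 20), schema=schema)

    # Columns given as plain lists are converted via pyarrow.table()
    ref.put({"antenna": [0, 1, 2], "gain": [1.0, 0.9, 1.1]})

    # Read back in several representations
    table = ref.get(timeout=1.0)  # pyarrow.Table
    frame = ref.get_pandas()      # pandas DataFrame
    return table, frame, ref.get_dict()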