import abc
import logging
import select
import time
import weakref
from typing import Dict, List, Mapping, Optional, Tuple, Union

import awkward
import numpy
import pyarrow
import pyarrow.plasma as plasma

from . import common
# Module-level logger, named after this module
logger: logging.Logger = logging.getLogger(__name__)
class Connection:
    """
    A connection to a Plasma store.

    Subscribes to store notifications and uses them to maintain a local
    table of the objects in the store. Namespace declarations are tracked
    especially: their schema metadata and declared procedure schemas are
    parsed and kept in separate tables.
    """

    def __init__(self, plasma_path: str):
        # Connect to Plasma
        self._client = plasma.connect(plasma_path)
        # Weak buffer cache: repeated requests for an object share one
        # buffer for as long as somebody still holds it
        self._buf_cache = weakref.WeakValueDictionary()
        # Subscribe to event updates
        self._client.subscribe()
        self._socket = self._client.get_notification_socket()
        # Initial object table (object ID -> info dict)
        self._obj_table = self._client.list()
        # Namespace declaration OID -> schema metadata / procedure schemas
        self._ns_table_meta = {}
        self._ns_table_procs = {}
        for oid in self._obj_table:
            if common.is_namespace_decl(oid):
                self._try_parse_namespace(oid)

    @property
    def client(self) -> plasma.PlasmaClient:
        return self._client

    @property
    def namespaces(self) -> List[plasma.ObjectID]:
        return list(self._ns_table_meta)

    @property
    def namespace_meta(self) -> Dict[plasma.ObjectID, Dict[bytes, bytes]]:
        return self._ns_table_meta

    @property
    def namespace_procs(
        self,
    ) -> Dict[plasma.ObjectID, Dict[str, pyarrow.Schema]]:
        # Inner keys are the decoded values of the declaration's "name"
        # column (presumably str - confirm against _make_namespace_decl)
        return self._ns_table_procs

    def object_exists(self, oid) -> bool:
        """Checks whether the given object ID is known to exist

        :param oid: Object ID to check
        """
        return oid in self._obj_table

    def object_size(self, oid) -> Optional[int]:
        """Gets the size of the given object

        :param oid: Object ID to check
        :returns: Data size of the object, or None if it is not known
        """
        obj_record = self._obj_table.get(oid)
        if obj_record is None:
            return None
        return obj_record["data_size"]

    def get_buffers(self, oids, timeout=None) -> List[plasma.PlasmaBuffer]:
        """Retrieve buffers for the given OIDs.

        Uses a weak cache to prevent duplicated requests to the Plasma
        store for buffers that are still alive.

        :param oids: Plasma object IDs
        :param timeout: Passed to Plasma (as its ``timeout_ms``) for the
            objects that are not cached; ``None`` waits indefinitely
        :returns: List of Plasma buffers in the order of ``oids``, with
            ``None`` for objects that could not be retrieved
        """
        oids = list(oids)
        # Strong references for the duration of this call, so cached
        # buffers cannot be collected between lookup and return
        found = {}
        for oid in oids:
            buf = self._buf_cache.get(oid)
            if buf is not None:
                found[oid] = buf
        # Request only what is not cached (deduplicated, order kept). The
        # result list must be zipped against exactly the requested IDs.
        to_request = list(
            dict.fromkeys(oid for oid in oids if oid not in found)
        )
        if to_request:
            bufs = self._client.get_buffers(
                to_request, -1 if timeout is None else timeout
            )
            for oid, buf in zip(to_request, bufs):
                if buf is not None:
                    self._buf_cache[oid] = buf
                    found[oid] = buf
        return [found.get(oid) for oid in oids]

    def update_obj_table(self, timeout: Optional[float] = 0):
        """Update the object table

        :param timeout: Maximum time to block waiting for the first
            notification. ``0`` only processes notifications that are
            already queued; ``None`` blocks until a notification arrives.
            Once one notification was received, only already-queued
            notifications are processed.
        :returns: A list of received update notifications
        """
        notifications = []
        fileno = self._socket.fileno()
        while True:
            # get_next_notification() blocks; select lets us time out.
            readable, _, _ = select.select([fileno], [], [], timeout)
            if not readable:
                break
            # Drain the rest without waiting. Reset for *every* kind of
            # notification, so deletions cannot extend the wait.
            timeout = 0
            notification = self._client.get_next_notification()
            notifications.append(notification)
            self._apply_notification(*notification)
        return notifications

    def _apply_notification(self, oid, data_size, metadata_size):
        # Apply one (oid, data_size, metadata_size) store notification.
        if data_size < 0 or metadata_size < 0:
            # Negative sizes signal deletion
            self._obj_table.pop(oid, None)
            self._ns_table_meta.pop(oid, None)
            self._ns_table_procs.pop(oid, None)
            return
        self._obj_table[oid] = {
            "data_size": data_size,
            "metadata_size": metadata_size,
        }
        if common.is_namespace_decl(oid):
            self._try_parse_namespace(oid)

    def _try_parse_namespace(self, oid: plasma.ObjectID):
        # Best-effort namespace parsing: a malformed declaration must not
        # break the connection or other code using it.
        try:
            self._parse_namespace(oid)
        except Exception as e:
            logger.warning("Failed to parse namespace: %s", e)

    def _parse_namespace(self, oid: plasma.ObjectID):
        # Read a namespace declaration into the namespace tables.
        buf_raw = self.get_buffers([oid], 0)[0]
        if buf_raw is None:
            # vanished before we got to it, simply ignore
            return
        # Copy out of the store and drop the Plasma buffer right away. We
        # do not want to leave behind trailing references to namespace
        # objects (py_buffer alone would not copy).
        buf = pyarrow.py_buffer(buf_raw.to_pybytes())
        del buf_raw
        procs = {}
        meta = {}
        if buf.size > 0:
            reader = pyarrow.ipc.open_stream(pyarrow.BufferReader(buf))
            schema = reader.schema
            if schema.metadata is not None:
                meta = schema.metadata
            # Read supported calls
            name_col = schema.get_field_index("name")
            schema_col = schema.get_field_index("schema")
            if name_col < 0 or schema_col < 0:
                raise ValueError(
                    "Namespace declaration lacks 'name' or 'schema' column"
                )
            for batch in reader:
                for name, schema_val in zip(
                    batch.column(name_col), batch.column(schema_col)
                ):
                    # A serialised schema looks just like a record batch
                    # stream without any record batches, so a stream
                    # reader can parse it.
                    call_schema = pyarrow.ipc.open_stream(
                        pyarrow.BufferReader(schema_val.as_buffer())
                    ).schema
                    procs[name.as_py()] = call_schema
        self._ns_table_meta[oid] = meta
        self._ns_table_procs[oid] = procs

    def reserve_namespace(
        self,
        name: Optional[str] = None,
        procs: Optional[List[pyarrow.Schema]] = None,
        prefix: bytes = b"",
    ) -> Tuple[bytes, plasma.PlasmaBuffer]:
        """Reserve a new namespace within the Plasma store

        This will automatically clear all objects with the chosen prefix.

        :param name: Informative display name for namespace
        :param procs: Call schemas supported (if any)
        :param prefix: Prefix that the chosen namespace prefix must
            start with
        :returns: Prefix, buffer with declaration (to keep namespace alive)
        :raises RuntimeError: If no free namespace could be found
        """
        if procs is None:
            procs = []
        # Make set of used prefixes
        self.update_obj_table()
        used_prefixes = {
            oid.binary()[: common.NAMESPACE_ID_SIZE] for oid in self._obj_table
        }
        # Serialise namespace declaration
        decl = common._make_namespace_decl(name, procs)
        stream = pyarrow.BufferOutputStream()
        writer = pyarrow.ipc.RecordBatchStreamWriter(stream, decl.schema)
        writer.write(decl)
        writer.close()
        data = stream.getvalue().to_pybytes()
        # Reserve a free namespace
        for candidate in common.objectid_generator(
            prefix, common.NAMESPACE_ID_SIZE
        ):
            if candidate in used_prefixes:
                continue
            decl_oid = plasma.ObjectID(next(common.objectid_generator(candidate)))
            try:
                # Attempt to write into Plasma
                self._client.create_and_seal(decl_oid, data, b"")
            except plasma.PlasmaObjectExists:
                # Somebody else claimed it first
                continue
            # Immediately delete without releasing (-> normally
            # will get deleted once client connection closes)
            root = self.get_buffers([decl_oid])[0]
            self._client.delete([decl_oid])
            # Request deletion of all objects in new namespace
            stale = [
                oid
                for oid in self._obj_table
                if oid.binary().startswith(candidate)
            ]
            if stale:
                self._client.delete(stale)
            # Found our prefix!
            return candidate, root
        # Nothing found!
        raise RuntimeError("Could not find a free Plasma namespace!")

    def get_ref_buffers(
        self,
        refs: List["Ref"],
        timeout: Optional[float] = None,
        auto_delete: bool = True,
    ) -> None:
        """
        Retrieves the buffers for multiple Plasma references at a time.

        Blocks as long as any (!) of the objects have not been created yet.

        :param refs: References to retrieve buffer of
        :param timeout: Maximum time this function is allowed to block.
            ``None`` blocks indefinitely; ``0`` only checks notifications
            that are already queued.
        :param auto_delete: Delete object in store when reference is dropped
        :raises: TimeoutException
        """
        # Collect references not resolved yet
        unresolved = [ref for ref in refs if ref._buf is None]
        if not unresolved:
            return
        deadline = None if timeout is None else time.monotonic() + timeout
        # Pick up already-queued notifications, so that objects created
        # before this call are found even with a zero timeout
        self.update_obj_table(0)
        while True:
            still_waiting = [
                ref for ref in unresolved if ref._oid not in self._obj_table
            ]
            if not still_waiting:
                break
            if deadline is None:
                # Block until the next notification arrives
                self.update_obj_table(None)
                continue
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                raise TimeoutException(still_waiting, timeout)
            self.update_obj_table(remaining)
        # Batch-request objects
        oids = [ref._oid for ref in unresolved]
        bufs = self.get_buffers(oids, 0)
        if auto_delete:
            # As we are holding buffer references, this will only
            # actually happen once no buffer reference exists any more
            self._client.delete(oids)
        for ref, buf in zip(unresolved, bufs):
            ref._buf = buf
        # Final check: an object may have vanished in the meantime. Done
        # only after the loop above to retain as many references as
        # possible.
        for ref in unresolved:
            if ref._buf is None:
                raise RuntimeError(
                    f"Object {common.object_id_hex(ref._oid)} vanished "
                    "before it could be read!"
                )
class Ref(metaclass=abc.ABCMeta):
    """Refers to an object in storage

    Subclassed by type. Might not have been created yet. Can have two
    kinds of relationships with other objects:

    * dependency: Object is required for this object to be
      created. Must ensure objects stay alive *until* this object is
      found to be created. Typically refers to call objects.
    * reference: Object that is referenced from this object and must
      therefore be kept alive while this object is still needed.
    """

    def __init__(
        self,
        conn: Connection,
        oid: plasma.ObjectID,
        auto_delete: bool = True,
        dependencies: Optional[List["Ref"]] = None,
        references: Optional[List["Ref"]] = None,
    ):
        self._conn = conn
        self._oid = oid
        # Must be set before registering dependencies/references below:
        # that goes through the connection and reads these attributes.
        self._auto_delete = auto_delete
        self._buf = None
        self._dependencies = []
        self._references = []
        for dep in dependencies or ():
            self.add_dependency(dep)
        for ref in references or ():
            self.add_reference(ref)

    @property
    def oid(self) -> plasma.ObjectID:
        return self._oid

    def add_dependency(
        self, ref: Union["Ref", plasma.ObjectID], timeout: float = 0
    ) -> None:
        """Registers the identified object as a dependency.

        This will ensure that the object is kept alive *until* we have
        retrieved the data for this object. Blocks if the object does
        not yet exist in the store.

        :param ref: Reference (or object ID) to register as dependency
        :param timeout: Maximum time this method is allowed to
            block.
        :raises: TimeoutException
        """
        # Construct reference, if needed
        if not isinstance(ref, Ref):
            ref = Ref(self._conn, ref)
        # Hold the *dependency's* buffer so it stays alive (might raise
        # timeout exception). Its own auto_delete setting governs deletion.
        self._conn.get_ref_buffers([ref], timeout, ref._auto_delete)
        self._dependencies.append(ref)

    def add_reference(
        self, ref: Union["Ref", plasma.ObjectID], timeout: float = 0
    ) -> None:
        """Registers the identified object as referenced

        This will ensure that the object is kept alive as long as this
        object is referenced. Might block if the object does not yet
        exist in the Plasma store.

        :param ref: Reference (or object ID) to register as referenced
        :param timeout: Maximum time this method is allowed to
            block.
        :raises: TimeoutException
        """
        # Construct reference, if needed
        if not isinstance(ref, Ref):
            ref = Ref(self._conn, ref)
        # Hold the *referenced* object's buffer (might raise timeout
        # exception)
        self._conn.get_ref_buffers([ref], timeout, ref._auto_delete)
        self._references.append(ref)

    def get_buffer(
        self, timeout: Optional[float] = None, auto_delete: bool = True
    ) -> pyarrow.Buffer:
        """
        Get Arrow buffer for this Plasma reference.

        Blocks if the object has not yet been created.

        :param timeout: Maximum time this method is allowed to block.
        :param auto_delete: Delete object in store when reference is
            dropped. Deletion only happens if this reference was also
            constructed with ``auto_delete`` enabled.
        :raises: TimeoutException
        """
        if self._buf is None:
            self._conn.get_ref_buffers(
                [self], timeout, self._auto_delete and auto_delete
            )
        # Our object exists now, so dependencies need not be held any more
        self._dependencies = []
        return self._buf

    def __str__(self):
        return f"{self.__class__.__name__}({common.object_id_hex(self._oid)})"
class TimeoutException(Exception):
    """Raised when referenced objects do not appear in the store in time.

    :param refs: References whose objects were still missing
    :param timeout: The time limit that ran out (``0`` means the objects
        were only checked for, without waiting)
    """

    def __init__(self, refs, timeout):
        # Forward the arguments so that ``args`` is populated, which makes
        # the exception picklable and gives a useful repr
        super().__init__(refs, timeout)
        self._refs = refs
        self._timeout = timeout

    @property
    def refs(self):
        return self._refs

    @property
    def timeout(self):
        return self._timeout

    def __str__(self):
        oids = [common.object_id_hex(ref.oid) for ref in self.refs]
        plural = len(oids) > 1
        if self._timeout > 0:
            return (
                f"Object{'s' if plural else ''} {','.join(oids)} did "
                f"not appear in Plasma store within {self._timeout}s!"
            )
        return (
            f"Object{'s' if plural else ''} {','.join(oids)} "
            f"{'do' if plural else 'does'} not exist in Plasma store!"
        )
class TensorRef(Ref):
    """Reference to a tensor held in the Plasma store.

    The object need not exist yet: this wraps its object ID together
    with the expected element type and dimension names.
    """

    def __init__(
        self,
        conn: Connection,
        oid: plasma.ObjectID,
        typ: pyarrow.DataType = None,
        dim_names: List[str] = None,
        auto_delete: bool = True,
    ):
        self._typ = typ
        self._dim_names = dim_names
        super().__init__(conn, oid, auto_delete)

    @property
    def typ(self) -> pyarrow.DataType:
        return self._typ

    @property
    def dim_names(self) -> List[str]:
        return self._dim_names

    def put(self, arr: numpy.ndarray = None):
        """Store a value for this tensor.

        :param arr: Array to store. Defaults to an empty array with one
            zero-length axis per dimension name.
        """
        if arr is None:
            empty_shape = (0,) * len(self._dim_names)
            arr = numpy.empty(empty_shape)
        client = self._conn.client
        self._buf = common._put_numpy(
            client,
            self._oid,
            arr,
            self._typ,
            self._dim_names,
            self._auto_delete,
        )

    def get(self, timeout=None):
        """Read the tensor back from storage; may block.

        :param timeout: Upper bound on how long this method may block."""
        buf = self.get_buffer(timeout)
        oid_hex = common.object_id_hex(self._oid)
        return common._tensor_from_buf(
            buf, oid_hex, self._typ, self._dim_names
        )
class TableRef(Ref):
    """Refers to a Table in Plasma store.

    Might not have been created yet - wraps an Object ID and, optionally,
    the expected schema.
    """

    def __init__(
        self,
        conn: Connection,
        oid: plasma.ObjectID,
        schema: pyarrow.Schema = None,
        auto_delete: bool = True,
    ):
        self._schema = schema
        # Table once read from the store (see get)
        self._table = None
        super().__init__(conn, oid, auto_delete)

    @property
    def schema(self) -> pyarrow.Schema:
        return self._schema

    def put(
        self,
        table: Union[
            pyarrow.Table,
            Mapping[str, pyarrow.ChunkedArray],
            Mapping[str, pyarrow.Array],
            Mapping[str, list],
        ],
        max_chunksize=None,
    ):
        """Write the given table into storage.

        The table can be given as pandas :py:class:`DataFrame`,
        dictionary of strings to :py:class:`pyarrow.Array` or lists,
        which will be converted into the equivalent table. If the
        arrays are chunked, the chunks of all columns must match. See
        also :py:meth:`pyarrow.table`.

        :param table: Table to write.
        :param max_chunksize: Maximum size of record batches to split table
            into
        :raises TypeError: If a given table does not match the expected
            schema
        """
        if isinstance(table, pyarrow.Table):
            # Check schema compatibility
            if self._schema is not None:
                problems = common.schema_compatible(self._schema, table.schema)
                if problems:
                    raise TypeError(
                        "Wrong table schema: " + ", ".join(problems)
                    )
        else:
            # Convert into table
            table = pyarrow.table(table, schema=self._schema)
        # Schema to serialise with: the expected one if given, otherwise
        # the table's own (a stream writer cannot be built without one)
        schema = self._schema if self._schema is not None else table.schema
        # Get record batches, determine size sum
        batches = table.to_batches(max_chunksize)
        record_batch_size = sum(
            map(pyarrow.ipc.get_record_batch_size, batches)
        )
        # Determine base size required for headers by writing an empty
        # stream (there doesn't seem to be a more direct way).
        header_stream = pyarrow.BufferOutputStream()
        pyarrow.ipc.RecordBatchStreamWriter(header_stream, schema).close()
        table_size = header_stream.tell() + record_batch_size
        # Reserve in Plasma
        self._buf = self._conn.client.create(self._oid, table_size)
        # Write out batches
        buf_writer = pyarrow.FixedSizeBufferWriter(self._buf)
        writer = pyarrow.ipc.RecordBatchStreamWriter(buf_writer, schema)
        for batch in batches:
            writer.write_batch(batch)
        writer.close()
        # Check that prediction was right - if this breaks we want to
        # learn about it early.
        written = buf_writer.tell()
        if written != table_size:
            logger.warning(
                "Unexpected size of serialised table: %s != %s!",
                written,
                table_size,
            )
        # Seal + request deletion (happens once no buffer is held any more)
        self._conn.client.seal(self._oid)
        self._conn.client.delete([self._oid])

    def get(self, timeout=None):
        """Get Arrow table

        :param timeout: How long the function is allowed to block if object
            does not exist yet
        """
        # Table already read?
        if self._table is not None:
            return self._table
        # Read table from buffer
        buf_reader = pyarrow.BufferReader(self.get_buffer(timeout))
        reader = pyarrow.ipc.RecordBatchStreamReader(buf_reader)
        self._table = reader.read_all()
        return self._table

    def get_awkward(self):
        """Get table as Awkward array

        This can generally be done without copying the data.
        """
        return awkward.from_arrow(self.get())

    def get_pandas(self, *args, **kwargs):
        """Get table as Pandas dataframe

        :param kwargs: Parameters to panda conversion.
            See :py:func:`pyarrow.Table.to_pandas`.
        """
        return self.get().to_pandas(*args, **kwargs)

    def get_dict(self):
        """Get table as Python dictionary"""
        return self.get().to_pydict()