import abc
import logging
from typing import Any, Dict, Iterator, List
import numpy
import pyarrow
import pyarrow.plasma as plasma
from . import common, connection, store
logger = logging.getLogger(__name__)
[docs]class Caller(metaclass=abc.ABCMeta):
"""Base class for calls to a :py:class:`.processor.Processor` class
The constructor will create methods according to the passed call
schemas - both for single and for batch calls. The batch variant
will expect a list of dictionaries, see :py:func:`batch_call`.
:param procs: Call schemas to support. Will be used to find a
compatible processor.
:param store: Store area to use for calls (will use its Plasma client)
:param broadcast: Send calls to all matching processors?
:param minimum_processors: Raise an error if fewer processors are available
:param processor_prefix: Only use processors whose object ID starts with this prefix
:param max_attempts: Maximum attempts at resolving ObjectID collisions
:param verbose: Log information about found processors
"""
def __init__(
    self,
    procs: List[pyarrow.Schema],
    store: store.Store,
    broadcast: bool = False,
    minimum_processors: int = 1,
    processor_prefix: bytes = b"",
    max_attempts: int = 100,
    verbose: bool = False,
):
    """Set up the caller, locate processors and register call methods.

    :param procs: Call schemas to support
    :param store: Store area to use for calls
    :param broadcast: Send calls to all matching processors?
    :param minimum_processors: Raise an error if fewer processors
        are available
    :param processor_prefix: Prefix a processor's object ID must start
        with to be considered
    :param max_attempts: Maximum attempts at finding an unused object
        ID for a call
    :param verbose: Log information about found processors
    """
    # Call configuration
    self._procs = procs
    self._broadcast = broadcast
    self._minimum_processors = minimum_processors
    self._processor_prefix = processor_prefix
    self._max_attempts = max_attempts

    # Store handle and the connection it exposes
    self._store = store
    self._conn = store.conn

    # Processors we send calls to (object ID -> object ID generator)
    # and processors that were already inspected and rejected
    self._processors = dict()
    self._known_processors = set()
    self.find_processors(verbose)

    # Expose every call schema as a pair of convenience methods
    for call_schema in procs:
        self._register_call(call_schema)
@property
def num_processors(self):
    """Number of processors this caller currently sends calls to."""
    located = self._processors
    return len(located)
def find_processors(self, verbose=False):
    """Locate compatible processors.

    Done automatically when the caller is constructed. Call again
    to refresh the list of processors to call. Typically used with
    broadcasting callers.

    A processor is accepted if its object ID starts with the
    configured prefix and it provides a compatible schema for every
    call schema of this caller. Processors that were inspected and
    rejected are remembered and not checked again for as long as they
    stay in the namespace table.

    :param verbose: Log information about found processors
    :returns: Number of processors calls will be sent to
    :raises RuntimeError: If fewer compatible processors than the
        configured minimum are available
    """
    # Refresh object and namespace table
    self._conn.update_obj_table(0)

    # Get namespace table from store, clear all processors that vanished
    namespace_table = self._conn.namespace_procs
    available = set(namespace_table.keys())
    for oid in set(self._processors.keys()) - available:
        if verbose:
            logger.info(
                "Processor at prefix %s vanished",
                common.object_id_hex(oid),
            )
        del self._processors[oid]
    self._known_processors.intersection_update(available)

    # Find compatible processors
    for processor_oid, procs in namespace_table.items():
        # Skip processors that were already accepted or rejected
        if (
            processor_oid in self._processors
            or processor_oid in self._known_processors
        ):
            continue

        # Ignore processors with wrong prefix
        oid_bytes = processor_oid.binary()
        if not oid_bytes.startswith(self._processor_prefix):
            continue

        # Use the name from the namespace metadata where available
        processor_meta = self._conn.namespace_meta[processor_oid]
        processor_name = str(processor_oid)
        if common.PROC_NAMESPACE_META in processor_meta:
            processor_name = processor_meta[
                common.PROC_NAMESPACE_META
            ].decode()

        # Check all our procs, make sure we have a compatible target
        schema_matches = 0
        for expected_schema in self._procs:
            name = common.call_name(expected_schema)
            if name not in procs:
                continue
            problems = common.schema_compatible(
                expected_schema, procs[name]
            )
            if problems:
                # One placeholder per argument, otherwise logging
                # fails to format the record and the problems are lost
                logger.info(
                    "Processor '%s' has incompatible call "
                    "schema for '%s': %s",
                    processor_name,
                    name,
                    ", ".join(problems),
                )
            else:
                schema_matches += 1

        # Found?
        if schema_matches == len(self._procs):
            prefix = oid_bytes[: common.NAMESPACE_ID_SIZE]
            if verbose:
                logger.info(
                    "Found processor %s at prefix %s",
                    processor_name,
                    common.object_id_hex(prefix),
                )
            self._processors[processor_oid] = common.objectid_generator(
                prefix
            )
        else:
            self._known_processors.add(processor_oid)

    # Found enough processors?
    if len(self._processors) < self._minimum_processors:
        call_names = ", ".join(
            [f"'{common.call_name(call)}'" for call in self._procs]
        )
        if self._minimum_processors == 1:
            raise RuntimeError(
                f"Failed to find processor compatible with call schemas "
                f"for {call_names}!"
            )
        raise RuntimeError(
            f"Failed to find {self._minimum_processors} processors "
            f"compatible with call schemas for {call_names}!"
        )

    # Broadcast? Otherwise warn if there is more than one
    # processor. Might need some way to choose between them...
    if not self._broadcast and len(self._processors) > 1:
        if verbose:
            call_names = ", ".join(
                [f"'{common.call_name(call)}'" for call in self._procs]
            )
            logger.warning(
                "Multiple processors accept calls for %s!", call_names
            )
        # Drop entries from the front until a single processor is left
        while len(self._processors) > 1:
            del self._processors[next(iter(self._processors))]

    return self.num_processors
def batch_call(
    self, call_schema: pyarrow.Schema, calls: List[Dict[str, Any]]
) -> List[Dict[str, connection.TensorRef]]:
    """Issue one call per parameter dictionary using the given schema.

    Input values that are not references yet are written to the store
    first, and references are allocated for output parameters that
    were not passed. The passed dictionaries are copied, not modified.

    :param call_schema: Schema of the call
    :param calls: List of parameter dictionaries
    :returns: List of output parameter dictionaries per call
       (if broadcasting also per processor)
    """
    # Sort the schema fields by parameter direction
    field_count = len(call_schema.names)
    out_fields = [
        call_schema.field(idx)
        for idx in range(field_count)
        if common.par_meta(call_schema.field(idx))
        == common.PROC_PAR_META_OUT
    ]
    in_fields = [
        call_schema.field(idx)
        for idx in range(field_count)
        if common.par_meta(call_schema.field(idx))
        == common.PROC_PAR_META_IN
    ]

    # Work on copies, replacing plain input values by store references
    prepared_calls = [dict(pars) for pars in calls]
    for in_field in in_fields:
        elem_type = common.par_tensor_elem_type(in_field)
        dims = common.par_tensor_dim_names(in_field)
        tbl_schema = common.par_table_schema(in_field)
        for prepared in prepared_calls:
            if elem_type is not None:
                # Nothing to do if missing or already a tensor reference
                value = prepared.get(in_field.name)
                if value is None or isinstance(
                    value, connection.TensorRef
                ):
                    continue
                # Otherwise convert to ndarray and write to the store
                if not isinstance(value, numpy.ndarray):
                    value = numpy.array(
                        value, dtype=elem_type.to_pandas_dtype()
                    )
                prepared[in_field.name] = self._store.put_new_tensor(
                    value, elem_type, dims
                )
            if tbl_schema is not None:
                # Nothing to do if missing or already a table reference
                value = prepared.get(in_field.name)
                if value is None or isinstance(value, connection.TableRef):
                    continue
                prepared[in_field.name] = self._store.put_new_table(
                    value, tbl_schema
                )

    # One round per processor; only the first one unless broadcasting
    per_processor = []
    for oid_gen in self._processors.values():
        # Every processor gets its own copies and output references
        proc_calls = [dict(prepared) for prepared in prepared_calls]
        for out_field in out_fields:
            elem_type = common.par_tensor_elem_type(out_field)
            dims = common.par_tensor_dim_names(out_field)
            tbl_schema = common.par_table_schema(out_field)
            for proc_call in proc_calls:
                if out_field.name in proc_call:
                    continue
                if elem_type is not None:
                    proc_call[out_field.name] = self._store.new_tensor_ref(
                        elem_type, dims
                    )
                if tbl_schema is not None:
                    proc_call[out_field.name] = self._store.new_table_ref(
                        tbl_schema
                    )

        # Make call
        self._batch_call(call_schema, proc_calls, oid_gen)

        # Collect the output parameters of every call
        results = [
            {
                out_field.name: proc_call[out_field.name]
                for out_field in out_fields
            }
            for proc_call in proc_calls
        ]
        if not self._broadcast:
            return results
        per_processor.append(results)
    return per_processor
def _batch_call(
    self,
    call_schema: pyarrow.Schema,
    calls: List[Dict[str, Any]],
    oid_gen: Iterator[plasma.ObjectID],
) -> plasma.ObjectID:
    """Serialise a batch of calls and write it to the store.

    The parameter dictionaries are transposed into one column per
    schema field, converted into a record batch and written to a
    Plasma object whose ID is taken from ``oid_gen``.

    :param call_schema: Schema of the call
    :param calls: Parameter dictionaries, one per call. References are
        passed on as the binary form of their object ID.
    :param oid_gen: Source of candidate object IDs for the call object
    :returns: Object ID the call was written to
    :raises ValueError: If a non-nullable parameter is missing, a
        reference is passed for a parameter that is neither input nor
        output, or a parameter does not appear in the schema
    :raises RuntimeError: If no unused object ID was found within the
        configured maximum number of attempts
    """
    # Collect parameters as Python lists per column (transpose,
    # effectively). We assume names are in column order.
    fields = [call_schema.field(i) for i in range(len(call_schema.names))]
    cols = [[] for _ in fields]
    known_names = set(call_schema.names)
    in_refs = []
    out_refs = []
    for pars in calls:
        # Gather parameters for call
        for col, field in zip(cols, fields):
            val = pars.get(field.name)
            if val is None:
                # If parameter is missing, it must be nullable
                if not field.nullable:
                    raise ValueError(
                        f"Parameter {field.name} is missing, "
                        "but not nullable!"
                    )
            elif isinstance(val, connection.Ref):
                # Remember the reference, pass on its object ID
                par_meta = common.par_meta(field)
                if par_meta == common.PROC_PAR_META_IN:
                    in_refs.append(val)
                elif par_meta == common.PROC_PAR_META_OUT:
                    out_refs.append(val)
                else:
                    raise ValueError(
                        "Unexpected reference for parameter "
                        f"{field.name}!"
                    )
                val = val.oid.binary()
            # Always append (None for missing nullable parameters) so
            # all columns end up with one entry per call
            col.append(val)

        # Make sure we have used all passed parameters
        unused = [f"'{par}'" for par in pars if par not in known_names]
        if unused:
            raise ValueError(
                f"Parameters {', '.join(unused)} not used in call!"
            )

    # Convert to pyarrays of appropriate type, create record batch
    arrays = [
        pyarrow.array(col, typ)
        for col, typ in zip(cols, call_schema.types)
    ]
    batch = pyarrow.record_batch(arrays, call_schema)

    # Serialise once - the payload is the same for every attempt below
    sink = pyarrow.BufferOutputStream()
    writer = pyarrow.RecordBatchStreamWriter(sink, call_schema)
    writer.write_batch(batch)
    writer.close()
    payload = sink.getvalue().to_pybytes()

    # Create Plasma object for call, retrying with the next candidate
    # ID whenever the current one is already taken
    conn = self._store._conn
    for _ in range(self._max_attempts):
        try:
            oid = plasma.ObjectID(next(oid_gen))
            conn._client.create_and_seal(oid, payload)
            # Get+delete
            # NOTE(review): presumably the retained buffer keeps the
            # object alive after the delete request - confirm
            call_buf = conn.get_buffers([oid])[0]
            conn._client.delete([oid])
            break
        except plasma.PlasmaObjectExists:
            pass
    else:
        # Decode the function name so the message does not show a
        # bytes literal
        func_name = call_schema.metadata[common.PROC_FUNC_META].decode()
        raise RuntimeError(
            "Maximum number of retries reached while "
            "attempting to find unused object ID for call to "
            f"'{func_name}'!"
        )

    # Register inputs (and the call buffer) as dependencies of outputs
    for out_ref in out_refs:
        out_ref._dependencies = in_refs + [call_buf]
    return oid
def call(
    self, call_schema: pyarrow.Schema, *args, **kwargs
) -> List[Dict[str, connection.TensorRef]]:
    """Create a single call to a function with the given schema.

    Both positional and keyword arguments are supported, using the
    position and name of the parameter in the schema, respectively.

    :param call_schema: Schema of the call
    :param args: Parameters by position, in schema order
    :param kwargs: Parameters by name
    :returns: Result of :py:func:`batch_call` for a batch that contains
        only this call
    :raises TypeError: If more positional parameters are passed than
        the schema has fields, or a parameter is passed both by
        position and by name
    """
    names = call_schema.names
    # zip() below would silently drop surplus positional arguments
    if len(args) > len(names):
        raise TypeError(
            f"{common.call_name(call_schema)} takes at most "
            f"{len(names)} positional arguments but {len(args)} "
            "were given!"
        )
    for name, arg in zip(names, args):
        if name in kwargs:
            raise TypeError(
                f"{common.call_name(call_schema)} got multiple "
                f"values for argument '{name}'!"
            )
        kwargs[name] = arg
    return self.batch_call(call_schema, [kwargs])
def _register_call(self, call_schema: pyarrow.Schema):
    """Attach convenience methods for the given call to this instance.

    Two attributes are set: one named after the function in the
    schema metadata, making a single call, and one with an additional
    ``_batch`` suffix that takes a list of parameter dictionaries.

    :param call_schema: Schema of call to create
    """
    # Determine attribute names, refuse to overwrite anything existing
    single_name = call_schema.metadata[common.PROC_FUNC_META].decode()
    batch_name = f"{single_name}_batch"
    assert not hasattr(self, single_name)
    assert not hasattr(self, batch_name)

    def call_func(*args, **kwargs):
        results = self.call(call_schema, *args, **kwargs)
        # Unwrap the batch of one - per processor when broadcasting
        if not self._broadcast:
            return results[0]
        return [per_processor[0] for per_processor in results]

    def call_func_batch(par_vals):
        return self.batch_call(call_schema, par_vals)

    setattr(self, single_name, call_func)
    setattr(self, batch_name, call_func_batch)