Source code for ska_sdp_dal.caller

import abc
import logging
from typing import Any, Dict, Iterator, List

import numpy
import pyarrow
import pyarrow.plasma as plasma

from . import common, connection, store

logger = logging.getLogger(__name__)


class Caller(metaclass=abc.ABCMeta):
    """Base class for calls to a :py:class:`.processor.Processor` class

    The constructor will create methods according to the passed call
    schemas - both for single and for batch calls. The batch variant
    will expect a list of dictionaries, see :py:func:`batch_call`.

    :param procs: Call schemas to support. Will be used to find a
       compatible processor.
    :param store: Store area to use for calls (will use its Plasma client)
    :param broadcast: Send calls to all matching processors?
    :param minimum_processors: Raise an error if fewer processors are
       available
    :param processor_prefix: Only consider processors whose namespace
       object ID starts with these bytes (the default empty prefix
       matches every processor)
    :param max_attempts: Maximum attempts at resolving ObjectID collisions
    :param verbose: Log information about found processors
    """

    def __init__(
        self,
        procs: List[pyarrow.Schema],
        store: store.Store,
        broadcast: bool = False,
        minimum_processors: int = 1,
        processor_prefix: bytes = b"",
        max_attempts: int = 100,
        verbose: bool = False,
    ):
        # Initialise
        self._store = store
        self._conn = store.conn
        self._procs = procs
        self._broadcast = broadcast
        self._minimum_processors = minimum_processors
        self._processor_prefix = processor_prefix
        self._max_attempts = max_attempts

        # Processors that calls will be sent to, mapped to the object ID
        # generator for their namespace prefix
        self._processors = {}
        # Processors that were already checked and found incompatible,
        # so they are not re-checked on every refresh
        self._known_processors = set()
        self.find_processors(verbose)

        # Create one single-call and one batch-call method per schema
        for proc in procs:
            self._register_call(proc)

    @property
    def num_processors(self):
        """The number of processors located by this caller"""
        return len(self._processors)

    def find_processors(self, verbose=False):
        """Locate compatible processors.

        Done automatically when the caller is constructed. Call again to
        refresh the list of processors to call. Typically used with
        broadcasting callers.

        :param verbose: Log information about found processors
        :returns: Number of processors that calls will be sent to
        :raises RuntimeError: If fewer than ``minimum_processors``
           compatible processors were found
        """

        # Refresh object and namespace table
        self._conn.update_obj_table(0)

        # Get namespace table from store, clear all processors that vanished
        namespace_table = self._conn.namespace_procs
        present = set(namespace_table.keys())
        for oid in set(self._processors.keys()) - present:
            if verbose:
                logger.info(
                    "Processor at prefix %s vanished",
                    common.object_id_hex(oid),
                )
            del self._processors[oid]
        self._known_processors &= present

        # Find compatible processors
        for processor_oid, procs in namespace_table.items():
            # Already checked (either accepted or rejected)?
            if (
                processor_oid in self._processors
                or processor_oid in self._known_processors
            ):
                continue
            # Ignore processors with wrong prefix
            if not processor_oid.binary().startswith(self._processor_prefix):
                continue

            # Determine a human-readable name for log messages
            processor_meta = self._conn.namespace_meta[processor_oid]
            processor_name = str(processor_oid)
            if common.PROC_NAMESPACE_META in processor_meta:
                processor_name = processor_meta[
                    common.PROC_NAMESPACE_META
                ].decode()

            # Check all our procs, make sure we have a compatible target
            schema_matches = 0
            for expected_schema in self._procs:
                name = common.call_name(expected_schema)
                if name not in procs:
                    continue
                # Check for schema compatibility
                problems = common.schema_compatible(
                    expected_schema, procs[name]
                )
                if problems:
                    logger.info(
                        "Processor '%s' has incompatible call "
                        "schema for '%s': %s",
                        processor_name,
                        name,
                        ", ".join(problems),
                    )
                else:
                    schema_matches += 1

            # Found? Only processors supporting *every* call are accepted
            if schema_matches == len(self._procs):
                prefix = processor_oid.binary()[: common.NAMESPACE_ID_SIZE]
                if verbose:
                    logger.info(
                        "Found processor %s at prefix %s",
                        processor_name,
                        common.object_id_hex(prefix),
                    )
                self._processors[processor_oid] = common.objectid_generator(
                    prefix
                )
            else:
                self._known_processors.add(processor_oid)

        # Found enough processors?
        if len(self._processors) < self._minimum_processors:
            call_names = ", ".join(
                f"'{common.call_name(call)}'" for call in self._procs
            )
            if self._minimum_processors == 1:
                raise RuntimeError(
                    "Failed to find processor compatible with call schemas "
                    f"for {call_names}!"
                )
            raise RuntimeError(
                f"Failed to find {self._minimum_processors} processors "
                f"compatible with call schemas for {call_names}!"
            )

        # Broadcast? Otherwise warn if there is more than one
        # processor. Might need some way to choose between them...
        if not self._broadcast and len(self._processors) > 1:
            call_names = ", ".join(
                f"'{common.call_name(call)}'" for call in self._procs
            )
            if verbose:
                logger.warning(
                    "Multiple processors accept calls for %s!", call_names
                )
            # Keep only the most recently inserted processor
            while len(self._processors) > 1:
                del self._processors[next(iter(self._processors))]

        return self.num_processors

    def batch_call(
        self, call_schema: pyarrow.Schema, calls: List[Dict[str, Any]]
    ) -> List[Dict[str, connection.TensorRef]]:
        """Create a number of calls to a function with the given schema.

        Input parameters that are not already references are written to
        the store first. Output parameters that were not passed get fresh
        references allocated.

        :param call_schema: Schema of the call
        :param calls: List of parameter dictionaries
        :returns: List of output parameter dictionaries per call (if
          broadcasting also per processor)
        """

        # Get parameter kinds
        fields = [call_schema.field(i) for i in range(len(call_schema.names))]
        output_pars = [
            field
            for field in fields
            if common.par_meta(field) == common.PROC_PAR_META_OUT
        ]
        input_pars = [
            field
            for field in fields
            if common.par_meta(field) == common.PROC_PAR_META_IN
        ]

        # Put input tensors where required (copy dicts, so the caller's
        # dictionaries are not modified)
        in_calls = [dict(call) for call in calls]
        for in_field in input_pars:
            tensor_type = common.par_tensor_elem_type(in_field)
            dim_names = common.par_tensor_dim_names(in_field)
            table_schema = common.par_table_schema(in_field)
            for call in in_calls:
                if tensor_type is not None:
                    # Already appropriate type?
                    v = call.get(in_field.name)
                    if v is None or isinstance(v, connection.TensorRef):
                        continue
                    # Otherwise try to convert to ndarray and write to store
                    if not isinstance(v, numpy.ndarray):
                        v = numpy.array(
                            v, dtype=tensor_type.to_pandas_dtype()
                        )
                    call[in_field.name] = self._store.put_new_tensor(
                        v, tensor_type, dim_names
                    )
                if table_schema is not None:
                    # Already a table ref?
                    v = call.get(in_field.name)
                    if v is None or isinstance(v, connection.TableRef):
                        continue
                    # Create in store
                    call[in_field.name] = self._store.put_new_table(
                        v, table_schema
                    )

        # Possibly broadcast to processors
        outputs = []
        for oid_gen in self._processors.values():
            # Add output parameters, if missing (fresh refs per processor)
            proc_calls = [dict(call) for call in in_calls]
            for out_field in output_pars:
                tensor_type = common.par_tensor_elem_type(out_field)
                dim_names = common.par_tensor_dim_names(out_field)
                table_schema = common.par_table_schema(out_field)
                for call in proc_calls:
                    if out_field.name in call:
                        continue
                    if tensor_type is not None:
                        call[out_field.name] = self._store.new_tensor_ref(
                            tensor_type, dim_names
                        )
                    if table_schema is not None:
                        call[out_field.name] = self._store.new_table_ref(
                            table_schema
                        )

            # Make call
            self._batch_call(call_schema, proc_calls, oid_gen)

            # Collect outputs
            output = [
                {
                    out_field.name: call[out_field.name]
                    for out_field in output_pars
                }
                for call in proc_calls
            ]
            if not self._broadcast:
                return output
            outputs.append(output)

        return outputs

    def _batch_call(
        self,
        call_schema: pyarrow.Schema,
        calls: List[Dict[str, Any]],
        oid_gen: Iterator[plasma.ObjectID],
    ) -> plasma.ObjectID:
        """Serialise a batch of calls and place it in the store.

        :param call_schema: Schema of the call
        :param calls: Parameter dictionaries, references already resolved
        :param oid_gen: Object ID generator for the target processor
        :returns: Object ID used for the call object
        :raises ValueError: On missing non-nullable or unknown parameters
        :raises RuntimeError: If no unused object ID could be found
        """

        # Collect parameters as Python dictionaries (transpose,
        # effectively). We assume names are in column order.
        cols = [[] for _ in call_schema.names]
        fields = [call_schema.field(i) for i in range(len(cols))]
        in_refs = []
        out_refs = []
        for pars in calls:

            # Gather parameters for call
            for col, field in zip(cols, fields):

                # If parameter is missing, it must be nullable
                val = pars.get(field.name)
                if val is None:
                    if not field.nullable:
                        raise ValueError(
                            f"Parameter {field.name} is missing, "
                            "but not nullable!"
                        )
                elif isinstance(val, connection.Ref):
                    # Convert references to their object ID, remember
                    # them for dependency tracking
                    par_meta = common.par_meta(field)
                    if par_meta == common.PROC_PAR_META_IN:
                        in_refs.append(val)
                    elif par_meta == common.PROC_PAR_META_OUT:
                        out_refs.append(val)
                    else:
                        raise ValueError(
                            "Unexpected reference for parameter "
                            f"{field.name}!"
                        )
                    val = val.oid.binary()

                # Set (None for missing nullable parameters)
                col.append(val)

            # Make sure we have used all passed parameters
            unused = [
                f"'{par}'" for par in pars if par not in call_schema.names
            ]
            if unused:
                raise ValueError(
                    f"Parameters {', '.join(unused)} not used in call!"
                )

        # Convert to pyarrays of appropriate type, create record batch
        arrays = [
            pyarrow.array(col, typ) for col, typ in zip(cols, call_schema.types)
        ]
        batch = pyarrow.record_batch(arrays, call_schema)

        # Serialise
        sink = pyarrow.BufferOutputStream()
        writer = pyarrow.RecordBatchStreamWriter(sink, call_schema)
        writer.write_batch(batch)
        writer.close()
        payload = sink.getvalue().to_pybytes()

        # Create Plasma object for call, retrying on ID collisions
        call_bufs = []
        for _ in range(self._max_attempts):
            oid = plasma.ObjectID(next(oid_gen))
            try:
                self._store._conn._client.create_and_seal(oid, payload)
            except plasma.PlasmaObjectExists:
                continue

            # Get+delete: keep a buffer reference alive (tracked as a
            # dependency below) while releasing our claim on the object
            buf = self._store._conn.get_buffers([oid])[0]
            self._store._conn._client.delete([oid])
            call_bufs.append(buf)
            break
        else:
            func_name = call_schema.metadata[common.PROC_FUNC_META].decode()
            raise RuntimeError(
                "Maximum number of retries reached while "
                "attempting to find unused object ID for call to "
                f"'{func_name}'!"
            )

        # Register inputs as dependencies of outputs
        for out_ref in out_refs:
            out_ref._dependencies = in_refs + call_bufs

        return oid

    def call(self, call_schema: pyarrow.Schema, *args, **kwargs) -> List[Any]:
        """Make a single call to a function with the given schema.

        Both positional and keyword arguments are supported, using the
        position and name of the parameter in the schema, respectively.

        :param call_schema: Schema of the call
        :param args: List of parameters
        :param kwargs: Dictionary of parameters
        :returns: Result of :py:meth:`batch_call` for the single call
        :raises TypeError: If too many positional arguments are passed, or
           a parameter is given both positionally and by keyword
        """
        if len(args) > len(call_schema.names):
            raise TypeError(
                f"{common.call_name(call_schema)} takes "
                f"{len(call_schema.names)} positional arguments but "
                f"{len(args)} were given!"
            )
        for name, arg in zip(call_schema.names, args):
            if name in kwargs:
                raise TypeError(
                    f"{common.call_name(call_schema)} got multiple "
                    f"values for argument '{name}'!"
                )
            kwargs[name] = arg
        return self.batch_call(call_schema, [kwargs])

    def _register_call(self, call_schema: pyarrow.Schema):
        """Creates a method for calling function on this class.

        Meant to be called from constructors of derived functions to
        create handy wrappers for making calls. Constructs a call schema
        with the given parameters behind the scenes.

        :param call_schema: Schema of call to create
        :raises ValueError: If a generated method name would overwrite an
           existing attribute
        """

        # Determine function names, check that we don't overwrite something
        # (explicit checks rather than asserts, which vanish under -O)
        func_name = call_schema.metadata[common.PROC_FUNC_META].decode()
        batch_func_name = func_name + "_batch"
        for attr_name in (func_name, batch_func_name):
            if hasattr(self, attr_name):
                raise ValueError(
                    f"Cannot register call '{func_name}': attribute "
                    f"'{attr_name}' already exists!"
                )

        # Generate and set functions
        def call_func(*args, **kwargs):
            batch_results = self.call(call_schema, *args, **kwargs)
            if self._broadcast:
                # One single-call result per processor
                return [ret[0] for ret in batch_results]
            return batch_results[0]

        def call_func_batch(par_vals):
            return self.batch_call(call_schema, par_vals)

        call_func.__name__ = func_name
        call_func_batch.__name__ = batch_func_name
        setattr(self, func_name, call_func)
        setattr(self, batch_func_name, call_func_batch)