Source code for send_to_squey

import datetime
from enum import StrEnum, auto
import importlib
import json
import time

import pyarrow as pa
import pyarrow.flight as fl
import pyarrow.parquet as pq

from rich.progress import Progress

import sys
if sys.version_info >= (3, 8):
    from importlib import metadata
else:
    import importlib_metadata as metadata

instance_started = False

class CloudProvider(StrEnum):
    AWS = auto()

class CloudInstance:
    cloud_provider = None
    instance_id = None
    profile_name = None
    region_name = None

    def __init__(self, cloud_provider, instance_id, profile_name, region_name):
        self.cloud_provider = cloud_provider
        self.instance_id = instance_id
        self.profile_name = profile_name
        self.region_name = region_name

class SqueyInstance:
    """A connection to a remote Squey instance, used to upload and import datasets."""

    endpoint = None
    auth = None
    instance_id = None
    client = None
    options = None

    def __init__(self, endpoint, auth, disable_server_verification, compression_codec):
        self.endpoint = endpoint
        self.auth = auth
        self.client = None
        self.compression_codec = compression_codec
        self.version = metadata.version('send-to-squey')

        def connect(e, v):
            return fl.connect(f"grpc+tls://{e}:5005", disable_server_verification=v)

        # Wait for the Apache Arrow Flight service to be up and running
        if instance_started:
            bar = Progress()
            bar.add_task(description="Waiting for Squey service to start", total=None)
            bar.start()
            while True:
                try:
                    self.client = connect(endpoint, disable_server_verification)
                    try:
                        list(self.client.list_flights())
                        break  # service reachable without authentication
                    except fl.FlightUnauthenticatedError:
                        break  # service is up; authentication happens below
                except fl.FlightUnavailableError:
                    time.sleep(10)
            bar.stop()
        else:
            self.client = connect(endpoint, disable_server_verification)

        # Authenticate against the Apache Arrow Flight server
        token_pair = self.client.authenticate_basic_token(auth[0], auth[1])
        self.options = fl.FlightCallOptions(
            headers=[token_pair],
            write_options=pa.ipc.IpcWriteOptions(compression=compression_codec)
        )
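    # Usage sketch (hypothetical endpoint and credentials; instances are normally
    # obtained through the module-level connect() helper defined at the end of
    # this module):
    #
    #     instance = SqueyInstance("squey.example.com", ("user", "password"),
    #                              disable_server_verification=False,
    #                              compression_codec="lz4")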
    def import_data(self, data, dataset_name=None):
        """
        Uploads a dataset to the cloud Squey instance and imports it.

        :param data: any object accepted by the ``data`` parameter of pyarrow.record_batch \
        (see https://arrow.apache.org/docs/python/generated/pyarrow.record_batch.html#pyarrow-record-batch) \
        or the path to a Parquet file.
        :param dataset_name: the name of the dataset that will be displayed in Squey \
        (defaults to the current date and time).
        :type dataset_name: string
        """
        if dataset_name is None:
            dataset_name = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')

        if isinstance(data, str):
            data = pq.ParquetFile(data)
            num_rows = data.metadata.num_rows
        else:
            num_rows = len(data)
        batch_size = max(int(num_rows / 100), 100)

        if isinstance(data, pq.ParquetFile):
            progress_total = num_rows
            record_batches = self._ParquetBatchReader(data, batch_size)
        elif hasattr(data, "iloc"):  # pandas-like data supports positional slicing
            progress_total = num_rows
            split_func = lambda data, start, end: data.iloc[start:end]
            record_batches = self._IndexBasedBatchReader(data, batch_size, split_func)
        else:
            progress_total = None
            split_func = lambda data, start, end: data
            record_batches = self._IndexBasedBatchReader(data, num_rows, split_func)

        bar = Progress()
        description = f"Uploading '{dataset_name}' to Squey remote instance ({self.endpoint})"
        task_id = bar.add_task(description, total=progress_total)
        bar.start()

        descriptor_info = {
            "path": f"{dataset_name}.parquet",
            "compression": self.compression_codec,
            "version": self.version
        }
        descriptor = fl.FlightDescriptor.for_command(json.dumps(descriptor_info))

        writer = None
        for record_batch in record_batches:
            if writer is None:
                writer, _ = self.client.do_put(
                    descriptor,
                    record_batch.schema,
                    options=self.options
                )
            writer.write_batch(record_batch)
            bar.advance(task_id, record_batch.num_rows)
        if writer is not None:
            writer.close()  # finalize the Flight stream
        bar.stop()
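    # Example call (hypothetical data; assumes pandas is installed, but any
    # object accepted by pyarrow.record_batch works the same way):
    #
    #     import pandas as pd
    #     df = pd.DataFrame({"ip": ["10.0.0.1", "10.0.0.2"], "port": [443, 80]})
    #     instance.import_data(df, dataset_name="connections")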
    class _DataBatchReader:
        """Base iterator yielding pyarrow record batches from a data source."""

        def __init__(self, data_source, batch_size):
            self.data_source = data_source
            self.batch_size = batch_size

        def __iter__(self):
            return self

        def __next__(self):
            raise NotImplementedError("__next__ method not implemented.")

    class _IndexBasedBatchReader(_DataBatchReader):
        """Yields record batches by slicing an index-addressable data source."""

        def __init__(self, data_source, batch_size, split_func):
            super().__init__(data_source, batch_size)
            self.num_rows = len(data_source)
            self.split_func = split_func
            self.current_index = 0

        def __next__(self):
            if self.current_index >= self.num_rows:
                raise StopIteration
            start = self.current_index
            self.current_index = end = min(start + self.batch_size, self.num_rows)
            batch = self.split_func(self.data_source, start, end)
            return pa.record_batch(batch)

    class _ParquetBatchReader(_DataBatchReader):
        """Yields record batches read incrementally from a Parquet file."""

        def __init__(self, data_source, batch_size):
            super().__init__(data_source, batch_size)
            self.parquet_file = data_source
            self.batch_iterator = self.parquet_file.iter_batches(batch_size=batch_size)

        def __next__(self):
            return next(self.batch_iterator)
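    # Batching illustration (hypothetical values): with batch_size=2 and a
    # 5-row pandas DataFrame, _IndexBasedBatchReader yields pyarrow record
    # batches of 2, 2 and 1 rows:
    #
    #     reader = SqueyInstance._IndexBasedBatchReader(df, 2, lambda d, s, e: d.iloc[s:e])
    #     [batch.num_rows for batch in reader]  # -> [2, 2, 1]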
def start_instance(
    instance_id,
    cloud_provider=CloudProvider.AWS,
    profile_name=None,
    region_name=None
):
    """
    Start a cloud instance if it is not already running.

    :param instance_id: the id of the instance
    :type instance_id: string
    :param cloud_provider: a value of the CloudProvider enum
    :type cloud_provider: enum
    :param profile_name: the profile name
    :type profile_name: string
    :param region_name: the region code
    :type region_name: string
    """
    provider_module = importlib.import_module("squeylab.cloud_providers." + cloud_provider)
    if not provider_module.is_instance_running(instance_id, profile_name, region_name):
        bar = Progress()
        bar.add_task(description="Waiting for Squey instance to start", total=None)
        bar.start()
        global instance_started
        instance_started = True
        provider_module.start_instance(instance_id, profile_name, region_name)
        bar.stop()
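# Example call (hypothetical AWS instance id, profile and region; assumes the
# corresponding squeylab.cloud_providers.aws module and credentials are set up):
#
#     start_instance("i-0123456789abcdef0", CloudProvider.AWS,
#                    profile_name="default", region_name="eu-west-1")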
def connect(endpoint, auth, disable_server_verification=False, compression_codec="lz4"):
    """
    Connect to the Squey server and return a :class:`SqueyInstance` object.

    :param endpoint: the connection endpoint (the Apache Arrow Flight service is expected on port 5005).
    :type endpoint: string
    :param auth: a tuple composed of username and password.
    :type auth: (string, string)
    :param disable_server_verification: disable SSL server verification.
    :type disable_server_verification: bool
    :param compression_codec: compression codec used ("lz4", "zstd" or None; defaults to "lz4")
    :type compression_codec: string
    :return: a :class:`SqueyInstance` object
    """
    return SqueyInstance(endpoint, auth, disable_server_verification, compression_codec)
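# End-to-end sketch (hypothetical endpoint, credentials and file path):
#
#     instance = connect("squey.example.com", ("user", "password"))
#     instance.import_data("events.parquet", dataset_name="events")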