Source code for sertit.dask

import contextlib
import logging
import os
from contextlib import contextmanager

import psutil
import xarray as xr

from sertit import logs
from sertit.types import AnyXrDataStructure

LOGGER = logging.getLogger(logs.SU_NAME)

DEFAULT_CHUNKS = "auto"
""" Default chunks used in Sertit library (if dask is installed) """

SERTIT_DEFAULT_CHUNKS = "SERTIT_DEFAULT_CHUNKS"
"""
Environment variable to override default chunks.

Available keywords (case agnostic): 
- Give :code:`NONE` to set :code:`None`
- Give :code:`TRUE` to set :code:`True`
- Give :code:`AUTO` to set :code:`"auto"`
"""


[docs] def is_dask_installed(): try: from dask import optimize # noqa: F401 from dask.distributed import get_client # noqa: F401 return True except ModuleNotFoundError: # pragma: no cover return False
[docs] def get_client(): client = None if is_dask_installed(): from dask.distributed import get_client with contextlib.suppress(ValueError): # Return default client client = get_client() else: # pragma: no cover LOGGER.warning( "Can't import 'dask'. If you experiment out of memory issue, consider installing 'dask'." ) return client
[docs] @contextmanager def get_or_create_dask_client(processes=False, env_vars=None): """ Return default Dask client or create a local cluster and linked client if not existing Returns: """ client = None try: if is_dask_installed(): from dask.distributed import Client, get_client try: # Return default client client = get_client() except ValueError: if processes: # Gather information to create a client adapted to the computer ram_info = psutil.virtual_memory() available_ram = ram_info.available / 1024 / 1024 / 1024 available_ram = 0.9 * available_ram n_workers = 1 memory_limit = f"{available_ram}Gb" if available_ram >= 16: n_workers = available_ram // 16 memory_limit = f"{16}Gb" # Create a local cluster and return client LOGGER.warning( f"Init local cluster with {n_workers} workers and {memory_limit} per worker" ) client = Client( n_workers=int(n_workers), threads_per_worker=4, memory_limit=memory_limit, ) else: # Create a local cluster (threaded) LOGGER.warning("Init local cluster (threaded)") client = Client( processes=processes, ) if env_vars: client.run( lambda: os.environ.update({k: v for k, v in env_vars.items()}) ) yield client else: # pragma: no cover LOGGER.warning( "Can't import 'dask'. If you experiment out of memory issue, consider installing 'dask'." ) finally: try: if client is not None: client.close() except Exception as ex: # pragma: no cover LOGGER.warning(ex)
[docs] def get_dask_lock(name): """ Get a dask lock with given name. This lock is used in rioxarray.to_raster (see https://corteva.github.io/rioxarray/stable/examples/dask_read_write.html) If a Multiple worker client exists: returns a distributed.Lock() Elif a Multithreaded client exists: returns a threading.Lock() else: returns None Args: name: The name of the lock Returns: """ lock = None if is_dask_installed(): from dask.distributed import Lock current_client = get_client() if current_client: lock = Lock(name) else: # pragma: no cover LOGGER.warning( "Can't import 'dask'. If you experiment out of memory issue, consider installing 'dask'." ) return lock
[docs] def is_chunked(array: AnyXrDataStructure) -> bool: """ Returns true if the array is still chunked. (i.e. its data is not computed, bnot loaded into memory as a numpy array) Args: array (AnyXrDataStructure): Array to check Returns: True if array is still chunked """ try: if isinstance(array, xr.DataArray): is_chunked = array.chunks is not None else: is_chunked = len(array.chunks) > 0 except AttributeError: is_chunked = False return is_chunked
[docs] def is_computed(array: AnyXrDataStructure) -> bool: """ Returns true if the array is still chunked. (i.e. its data is not computed, bnot loaded into memory as a numpy array) Args: array (AnyXrDataStructure): Array to check Returns: True if array is still chunked """ return not is_chunked(array)
[docs] def get_default_chunks(): """ Get the default chunks: - check if dask is available - check :code:`SERTIT_DEFAULT_CHUNKS` env variable - defaults on DEFAULT_CHUNKS """ chunks = None if is_dask_installed(): chunks = os.getenv(SERTIT_DEFAULT_CHUNKS, DEFAULT_CHUNKS) if chunks.lower() == "none": chunks = None elif chunks.lower() == "auto": chunks = "auto" elif chunks.lower() == "true": chunks = True return chunks
# From xarray: https://github.com/pydata/xarray/blob/30743945538ca2d276fc28eb221afa7bcb03978a/xarray/tests/__init__.py#L236-L259 class _CountingScheduler: """Simple dask scheduler counting the number of computes. Reference: https://stackoverflow.com/questions/53289286/""" def __init__( self, max_computes=0, nof_computes=None, dont_raise=False, force_synchronous=False, ): self.total_computes = 0 self.nof_computes = nof_computes if max_computes is None or ( nof_computes is not None and max_computes < nof_computes ): max_computes = nof_computes self.max_computes = max_computes self.dont_raise = dont_raise self.debug = force_synchronous # In case of debug, use the dask basic scheduler if self.debug: import dask self.get = dask.get else: self.get = None def __call__(self, dsk, keys, **kwargs): self.total_computes += 1 # Log where the compute happens import traceback for tb in traceback.extract_stack()[::-1]: # Change this condition if we are using dask elsewhere than rasters.py and we want to display the tb if ( tb.filename.lower().startswith("/home/data") and "/rasters.py" in tb.filename ): # pragma: no cover LOGGER.debug( f"Computation number {self.total_computes}: {tb.line} | {tb.name} in {tb.filename} at line {tb.lineno}" ) break # Raise or warn if too many computes have occurred if self.total_computes > self.max_computes: text = f"Too many computes. Total: {self.total_computes} > max: {self.max_computes}." if self.dont_raise: LOGGER.warning(text) else: raise RuntimeError(text) # Use the wanted get if self.get is None: client = get_client() if client: self.get = client.get if self.get is not None: return self.get(dsk, keys, **kwargs) def check_total_nof_computes(self): if self.nof_computes is not None and self.total_computes != self.nof_computes: text = f"Unexpected number of computes. Total: {self.total_computes} != {self.nof_computes}." if self.dont_raise: # pragma: no cover LOGGER.warning(text) else: raise RuntimeError(text) return True
[docs] @contextmanager def raise_if_dask_computes( max_computes=0, nof_computes=None, dont_raise=False, force_synchronous=False ): # return a dummy context manager so that this can be used for non-dask objects if not is_dask_installed(): # pragma: no cover yield contextlib.nullcontext() import dask scheduler = _CountingScheduler( max_computes=max_computes, nof_computes=nof_computes, dont_raise=dont_raise, force_synchronous=force_synchronous, ) try: yield dask.config.set(scheduler=scheduler) finally: scheduler.check_total_nof_computes() # Make sure the counting scheduler is removed dask.config.set(scheduler=None)