import os
import pickle
from sys import stderr
from time import perf_counter_ns
from typing import Any, Callable, List, Union, Final

import numpy as np

from config import MODEL_DIR, OUT_DIR
from decorators import njit


def formatted_row(gaps: list[int], titles: list[str], separator: str = '│') -> None:
    """Print a row of titles, each padded to its gap width and delimited by a separator.

    Args:
        gaps: Column widths; a negative width right-aligns the title of that
            column, a non-negative width left-aligns it
        titles: Titles to print, one per column
        separator: Character printed before each column and at the end of the row
    """
    for gap, title in zip(gaps, titles):
        # The sign of the gap selects the alignment: '>' (right) if negative, '<' (left) otherwise.
        print(f"{separator} {title:{'>' if gap < 0 else '<'}{abs(gap)}} ", end = '')
    print(separator)


def formatted_line(gaps: list[int], left: str, middle: str, separator: str, right: str) -> None:
    """Print a horizontal rule made of repeated characters.

    Each column is drawn as ``abs(gap) + 2`` copies of ``separator``, i.e. the
    column width plus the two padding spaces that ``formatted_row`` adds.

    Args:
        gaps: Column widths (the sign is ignored)
        left: Character printed at the start of the line
        middle: Character printed between two consecutive columns
        separator: Character repeated to fill each column
        right: Character printed at the end of the line
    """
    segments = (separator * (abs(gap) + 2) for gap in gaps)
    print(left + middle.join(segments) + right)


def header(gaps: list[int], titles: list[str]) -> None:
    """Print a table header: top border, title row and header separator.

    Args:
        gaps: Column widths (negative for right-aligned titles)
        titles: Titles of each column
    """
    formatted_line(gaps, '┌', '┬', '─', '┐')
    formatted_row(gaps, titles)
    formatted_line(gaps, '├', '┼', '─', '┤')


def footer(gaps: list[int]) -> None:
    """Print the bottom border of a table.

    Args:
        gaps: Column widths (the sign is ignored)
    """
    formatted_line(gaps, '└', '┴', '─', '┘')


# Unit suffixes from smallest to largest: nanosecond, microsecond, millisecond,
# second, minute, hour, day ('j'), week, month ('M', 31 days), year (365 days)
# and century ('c', 100 years of 365 days).
time_formats: Final = ['ns', 'µs', 'ms', 's', 'm', 'h', 'j', 'w', 'M', 'y', 'c']
# Length of each unit of ``time_formats`` expressed in nanoseconds.
time_numbers: Final = np.array([1, 1e3, 1e6, 1e9, 6e10, 36e11, 864e11, 6048e11, 26784e11, 31536e12, 31536e14], dtype = np.uint64)


@njit('str(uint64)')
def format_time_ns(time: int) -> str:
    """Format a duration in nanoseconds as a human readable string.

    The duration is decomposed greedily from the largest unit of
    ``time_formats`` down to nanoseconds, e.g. ``61_000_000_000`` gives ``'1m 1s'``.

    Args:
        time (int): Duration in nanoseconds

    Returns:
        str: The formatted human readable string ('0ns' for a zero duration)
    """
    assert time >= 0, 'Incorrect time stamp'
    if time == 0:
        return '0ns'
    s = ''
    # Walk the units from the largest to the smallest.
    for i in range(time_numbers.shape[0])[::-1]:
        if time >= time_numbers[i]:
            res = int(time // time_numbers[i])
            time = time % time_numbers[i]
            s += f'{res}{time_formats[i]} '

    assert time == 0, 'Leftover in formatting time !'
    return s.rstrip()


@njit('str(uint64)')
def format_time(time: int) -> str:
    """Format a duration in seconds as a human readable string.

    Same as ``format_time_ns`` but the smallest unit used is the second,
    e.g. ``3661`` gives ``'1h 1m 1s'``.

    Args:
        time (int): Duration in seconds

    Returns:
        str: The formatted human readable string ('0s' for a zero duration)
    """
    assert time >= 0, 'Incorrect time stamp'
    if time == 0:
        return '0s'
    s = ''
    # Index 3 is the second; larger units are converted from ns to seconds.
    for i in range(3, time_numbers.shape[0])[::-1]:
        time_number = time_numbers[i] / int(1e9)
        if time >= time_number:
            res = int(time // time_number)
            time = time % time_number
            s += f'{res}{time_formats[i]} '

    assert time == 0, 'Leftover in formatting time !'
    return s.rstrip()


def _pickle_path(save_dir: str, name: str) -> str:
    """Return the path of the pickle file named ``name`` inside ``save_dir``."""
    return f'{save_dir}/{name}.pkl'


def _load_pickle(filepath: str) -> Any:
    """Unpickle and return the content of ``filepath`` (only use on trusted files)."""
    with open(filepath, 'rb') as file_bytes:
        return pickle.load(file_bytes)


def _dump_pickle(filepath: str, data: Any) -> None:
    """Pickle ``data`` into ``filepath``, overwriting any existing file."""
    with open(filepath, 'wb') as file_bytes:
        pickle.dump(data, file_bytes)


def _print_loaded_row(step_name: str, column_width: int) -> None:
    """Print the table row reporting that a step's results were loaded from disk."""
    print(f"│ {step_name:<{column_width}} │ {'None':>18} │ {'loaded saved state':<29} │")


def pickle_multi_loader(filenames: List[str], save_dir: str = MODEL_DIR) -> List[Any]:
    """Load multiple pickle data files.

    Unpickling can execute arbitrary code: only load files this project wrote.

    Args:
        filenames (List[str]): Names (without the ``.pkl`` extension) of the files to load
        save_dir (str, optional): Directory of the files to load. Defaults to MODEL_DIR (see config.py)

    Returns:
        List[Any]: Loaded data, in the order of ``filenames``; ``None`` for every missing file
    """
    loaded = []
    for name in filenames:
        filepath = _pickle_path(save_dir, name)
        loaded.append(_load_pickle(filepath) if os.path.exists(filepath) else None)
    return loaded


def benchmark_function(step_name: str, column_width: int, fnc: Callable) -> Any:
    """Call a function, time it and print the timing as a table row on stdout.

    Args:
        step_name (str): Label of the step, shown on stderr while running and in the table row
        column_width (int): Width of the step name column of the table
        fnc (Callable): Function to call, without arguments

    Returns:
        Any: Result of the function
    """
    print(f'{step_name}...', file = stderr, end = '\r')
    start = perf_counter_ns()
    result = fnc()
    elapsed = perf_counter_ns() - start
    print(f'│ {step_name:<{column_width}} │ {elapsed:>18,} │ {format_time_ns(elapsed):<29} │')
    return result


def state_saver(step_name: str, column_width: int, filename: Union[str, List[str]], fnc: Callable,
                force_redo: bool = False, save_state: bool = True, save_dir: str = OUT_DIR) -> Any:
    """Either execute a function then save its result(s), or load the already saved result(s).

    Args:
        step_name (str): Label of the step, used in the progress messages and the table row
        column_width (int): Width of the step name column of the table
        filename (Union[str, List[str]]): Name, or list of names, of the file(s) (without the
            ``.pkl`` extension) where the result(s) are saved. With a list, ``fnc`` must return
            an iterable yielding one result per filename
        fnc (Callable): Function to call, without arguments
        force_redo (bool, optional): Recall the function even if the result(s) are already saved. Defaults to False
        save_state (bool, optional): Save the result(s) after calling the function. Defaults to True
        save_dir (str, optional): Directory where the result(s) are saved. Defaults to OUT_DIR (see config.py)

    Raises:
        TypeError: If ``filename`` is neither a string nor a list

    Returns:
        Any: The result(s) of the called function, or the loaded result(s)
    """
    if isinstance(filename, str):
        filepaths = [_pickle_path(save_dir, filename)]
    elif isinstance(filename, list):
        filepaths = [_pickle_path(save_dir, fn) for fn in filename]
    else:
        # Explicit raise rather than ``assert`` so the check survives ``python -O``.
        raise TypeError(f'Incompatible filename type = {type(filename)}')

    # Recompute when forced or when any of the expected files is missing.
    if force_redo or not all(os.path.exists(path) for path in filepaths):
        b = benchmark_function(step_name, column_width, fnc)
        if save_state:
            print(f'Saving results of {step_name}', file = stderr, end = '\r')
            if isinstance(filename, str):
                _dump_pickle(filepaths[0], b)
            else:
                for bi, path in zip(b, filepaths):
                    _dump_pickle(path, bi)
            print(' ' * 100, file = stderr, end = '\r')
        return b

    print(f'Loading results of {step_name}', file = stderr, end = '\r')
    loaded = [_load_pickle(path) for path in filepaths]
    print(' ' * 100, file = stderr, end = '\r')
    # Only report the step as loaded once every file has actually been read.
    _print_loaded_row(step_name, column_width)
    return loaded[0] if isinstance(filename, str) else loaded


@njit('boolean(int32[:, :], uint16[:, :])')
def unit_test_argsort_2d(arr: np.ndarray, indices: np.ndarray) -> bool:
    """Test whether each row of ``indices`` sorts the matching row of ``arr`` in ascending order.

    Every pair of consecutive indices of a row must satisfy
    ``arr[i, indices[i, j]] <= arr[i, indices[i, j + 1]]``. On failure, the number
    of sorted pairs, the expected number of pairs and their ratio are printed.

    Args:
        arr (np.ndarray): Array of data
        indices (np.ndarray): Indices that should sort each row of ``arr``

    Returns:
        bool: Success of the test (always True when rows have fewer than 2 elements)
    """
    n_rows, n_cols = indices.shape
    # Each row of n_cols indices holds n_cols - 1 consecutive pairs (none if the row is empty).
    total = n_rows * max(n_cols - 1, 0)
    n = 0
    for i in range(n_rows):
        for j in range(n_cols - 1):
            if arr[i, indices[i, j]] <= arr[i, indices[i, j + 1]]:
                n += 1

    if n != total:
        # n < total here, hence total > 0 and the ratio is well defined.
        print(n, total, n / total)
    return n == total