224 lines
7.0 KiB
Python
224 lines
7.0 KiB
Python
from typing import Any, Callable, List, Union, Final
|
|
from time import perf_counter_ns
|
|
import numpy as np
|
|
from sys import stderr
|
|
import pickle
|
|
import os
|
|
from config import MODEL_DIR, OUT_DIR, __DEBUG
|
|
from decorators import njit
|
|
|
|
def formatted_row(gaps: list[int], titles: list[str], separator: str = '│') -> None:
    """Print one table row, each title padded to its column width and separated by ``separator``.

    A negative gap right-aligns its title; a zero or positive gap left-aligns it.
    The padded width of a cell is ``abs(gap)``. Pairs are taken with ``zip``, so
    extra gaps or titles are ignored.

    Args:
        gaps: Column widths; the sign selects the alignment
        titles: Cell contents, one per column
        separator: Character printed before every cell and once more at the end of the row
    """
    cells = []
    for width, text in zip(gaps, titles):
        align = '<' if width >= 0 else '>'
        cells.append(f"{separator} {text:{align}{abs(width)}} ")
    print(''.join(cells) + separator)
|
|
|
|
def formatted_line(gaps: list[int], left: str, middle: str, separator: str, right: str) -> None:
    """Print a horizontal rule of a box-drawn table.

    Every column is drawn as ``abs(gap) + 2`` copies of ``separator``. The extra 2
    accounts for the blank padding that ``formatted_row`` puts around each cell.
    Columns are joined by ``middle``, and the rule is framed by ``left`` and ``right``.

    Args:
        gaps: Column widths (only their absolute value matters)
        left: Character at the start of the line
        middle: Character placed between two consecutive columns
        separator: Character repeated to fill each column
        right: Character at the end of the line
    """
    segments = (separator * (abs(width) + 2) for width in gaps)
    print(left + middle.join(segments) + right)
|
|
|
|
def header(gaps: list[int], titles: list[str]) -> None:
    """Print the top of a box-drawn table: top border, title row, then the rule under the titles.

    Args:
        gaps: Column widths (a negative width right-aligns its title)
        titles: Column titles, one per width
    """
    top_border = ('┌', '┬', '─', '┐')
    title_rule = ('├', '┼', '─', '┤')
    formatted_line(gaps, *top_border)
    formatted_row(gaps, titles)
    formatted_line(gaps, *title_rule)
|
|
|
|
def footer(gaps: list[int]) -> None:
    """Print the bottom border of a box-drawn table whose columns have the given widths.

    Args:
        gaps: Column widths, as passed to ``header`` (only their absolute value matters)
    """
    formatted_line(gaps, left = '└', middle = '┴', separator = '─', right = '┘')
|
|
|
|
# Suffixes of the supported time units, from the smallest to the largest.
# 'j' is the day unit, 'M' a 31-day month, 'y' a 365-day year and 'c' a century of such years
# (see the matching nanosecond counts below).
time_formats: Final = ['ns', 'µs', 'ms', 's', 'm', 'h', 'j', 'w', 'M', 'y', 'c']

# Number of nanoseconds in one of each unit above (same order), stored as exact integers.
time_numbers: Final = np.array([
    1,                                  # ns
    1_000,                              # µs
    1_000_000,                          # ms
    1_000_000_000,                      # s
    60_000_000_000,                     # m  = 60 s
    3_600_000_000_000,                  # h  = 3,600 s
    86_400_000_000_000,                 # j  = 86,400 s
    604_800_000_000_000,                # w  = 7 days
    2_678_400_000_000_000,              # M  = 31 days
    31_536_000_000_000_000,             # y  = 365 days
    3_153_600_000_000_000_000,          # c  = 100 * 365 days
], dtype = np.uint64)
|
|
@njit('str(uint64)')
def format_time_ns(time: int) -> str:
    """Format a duration given in nanoseconds as a human readable string.

    The duration is broken down greedily from the largest unit of ``time_formats``
    ('c') to the smallest ('ns'). Each unit that fits at least once contributes a
    ``'<count><suffix>'`` chunk, and chunks are separated by single spaces. Units that
    do not fit are omitted, e.g. ``90_061_000_000_001`` -> ``'1j 1h 1m 1s 1ns'``.

    Args:
        time (int): Duration in nanoseconds (declared as ``uint64`` in the ``njit`` signature)

    Returns:
        str: The formatted human readable string (``'0ns'`` for a zero duration)
    """
    # NOTE(review): if the 'str(uint64)' signature is enforced, time is unsigned and this cannot fail.
    assert time >= 0, 'Incorrect time stamp'
    if time == 0:
        return '0ns'

    s = ''
    # Walk the units from the largest to the smallest (reversed index range).
    for i in range(time_numbers.shape[0])[::-1]:
        if time >= time_numbers[i]:
            # Whole count of this unit; only the remainder is passed on to smaller units.
            res = int(time // time_numbers[i])
            time = time % time_numbers[i]
            s += f'{res}{time_formats[i]} '

    # The smallest unit is 1 ns, so the remainder must be 0 here.
    assert time == 0, 'Leftover in formatting time !'
    # Drop the trailing space left by the last chunk.
    return s.rstrip()
|
|
|
|
@njit('str(uint64)')
def format_time(time: int) -> str:
    """Format a duration given in seconds as a human readable string.

    Same layout as ``format_time_ns``, but sub-second units are never used: the
    smallest unit is 's'. Example: ``90061`` -> ``'1j 1h 1m 1s'``.

    Args:
        time (int): Duration in seconds (declared as ``uint64`` in the ``njit`` signature)

    Returns:
        str: The formatted human readable string (``'0s'`` for a zero duration)
    """
    # NOTE(review): if the 'str(uint64)' signature is enforced, time is unsigned and this cannot fail.
    assert time >= 0, 'Incorrect time stamp'
    if time == 0:
        return '0s'

    s = ''
    # Indices 0-2 (ns, µs, ms) are skipped; walk from the largest unit down to 's'.
    for i in range(3, time_numbers.shape[0])[::-1]:
        # Convert "nanoseconds per unit" into "seconds per unit".
        # NOTE(review): '/' is true division, so this arithmetic runs in floating point;
        # it is exact only while the values stay below 2**53. Integer floor division would avoid that.
        time_number = time_numbers[i] / int(1e9)
        if time >= time_number:
            res = int(time // time_number)
            time = time % time_number
            s += f'{res}{time_formats[i]} '

    # The smallest unit used is 1 s and the input is a whole number of seconds, so nothing is left.
    assert time == 0, 'Leftover in formatting time !'
    # Drop the trailing space left by the last chunk.
    return s.rstrip()
|
|
|
|
def pickle_multi_loader(filenames: List[str], save_dir: str = MODEL_DIR) -> List[Any]:
    """Load several pickle files, using ``None`` for those that do not exist.

    Each name ``f`` is read from ``{save_dir}/{f}.pkl``.

    Args:
        filenames (List[str]): Names of the files to load, without directory or ``.pkl`` extension
        save_dir (str, optional): Directory containing the files. Defaults to MODEL_DIR (see config.py)

    Returns:
        List[Any]: One entry per name, in the same order: the unpickled object, or ``None``
        when the file does not exist
    """
    def _load_or_none(name: str) -> Any:
        # Resolve one cache file; a missing file is reported as None rather than an error.
        path = f'{save_dir}/{name}.pkl'
        if not os.path.exists(path):
            return None
        with open(path, 'rb') as handle:
            return pickle.load(handle)

    return [_load_or_none(name) for name in filenames]
|
|
|
|
def benchmark_function(step_name: str, column_width: int, fnc: Callable) -> Any:
    """Time a single call of ``fnc`` and report it as a table row on stdout.

    A ``'<step_name>...'`` hint ending with a carriage return is written to stderr
    before the call, so the row printed afterwards can overwrite it on a terminal.

    Args:
        step_name (str): Label printed in the first column
        column_width (int): Width of the label column
        fnc (Callable): Zero-argument function to call

    Returns:
        Any: The value returned by ``fnc``
    """
    print(f'{step_name}...', file = stderr, end = '\r')
    start_ns = perf_counter_ns()
    result = fnc()
    elapsed_ns = perf_counter_ns() - start_ns

    # Columns: label | elapsed nanoseconds with thousands separators | human readable duration
    cells = (
        f'{step_name:<{column_width}}',
        f'{elapsed_ns:>18,}',
        f'{format_time_ns(elapsed_ns):<29}',
    )
    print('│ ' + ' │ '.join(cells) + ' │')
    return result
|
|
|
|
def state_saver(step_name: str, column_width: int, filename: Union[str, List[str]], fnc: Callable, force_redo: bool = False,
                save_state: bool = True, save_dir: str = OUT_DIR) -> Any:
    """Run ``fnc`` and cache its result(s) as pickle files, or reload the cached result(s).

    Each cache file is ``{save_dir}/{name}.pkl``. ``fnc`` is (re)run through
    ``benchmark_function`` when any of the files is missing or when ``force_redo`` is
    set. Otherwise the files are loaded and a ``loaded saved state`` row is printed
    in place of a timing row.

    Args:
        step_name (str): Label of the step, used in the table row and progress messages
        column_width (int): Width of the label column of the table row
        filename (Union[str, List[str]]): Name, or list of names, of the cache files (without
            directory or ``.pkl`` extension). With a list, ``fnc`` must return an iterable
            whose i-th item is saved under the i-th name
        fnc (Callable): Zero-argument function producing the result(s)
        force_redo (bool, optional): Re-run ``fnc`` even if every cache file exists. Defaults to False
        save_state (bool, optional): Write freshly computed result(s) to disk. Defaults to True
        save_dir (str, optional): Directory of the cache files. Defaults to OUT_DIR (see config.py)

    Returns:
        Any: Freshly computed: whatever ``fnc`` returned. Reloaded: the single object for a
        ``str`` filename, or a list with one object per file for a list of filenames

    Raises:
        TypeError: If ``filename`` is neither a ``str`` nor a ``list``
    """
    # Normalise both call styles to a list of names so save/load share one code path.
    single = isinstance(filename, str)
    if single:
        names = [filename]
    elif isinstance(filename, list):
        names = filename
    else:
        # Explicit raise instead of ``assert False`` so the check survives ``python -O``.
        raise TypeError(f'Incompatible filename type = {type(filename)}')

    paths = [f'{save_dir}/{name}.pkl' for name in names]

    if force_redo or not all(os.path.exists(path) for path in paths):
        b = benchmark_function(step_name, column_width, fnc)
        if save_state:
            print(f'Saving results of {step_name}', file = stderr, end = '\r')
            # A single filename stores the whole result; a list stores one item per file.
            for item, path in zip([b] if single else b, paths):
                with open(path, 'wb') as f:
                    pickle.dump(item, f)
            print(' ' * 100, file = stderr, end = '\r')
        return b

    print(f'Loading results of {step_name}', file = stderr, end = '\r')
    # NOTE: unpickling can execute arbitrary code; only point save_dir at trusted caches.
    loaded = []
    for path in paths:
        with open(path, 'rb') as f:
            loaded.append(pickle.load(f))
    # Clear the progress hint before printing the table row.
    print(' ' * 100, file = stderr, end = '\r')
    print(f"│ {step_name:<{column_width}} │ {'None':>18} │ {'loaded saved state':<29} │")
    return loaded[0] if single else loaded
|
|
|
|
@njit('boolean(int32[:, :], uint16[:, :])')
def unit_test_argsort_2d(arr: np.ndarray, indices: np.ndarray) -> bool:
    """Check that every row of ``indices`` sorts the matching row of ``arr`` in non-decreasing order.

    Row ``i`` passes when ``arr[i, indices[i, j]] <= arr[i, indices[i, j + 1]]`` for every
    adjacent pair ``j``. Only the ordering is checked: the function does not verify that
    ``indices[i]`` is a permutation of the column indices, nor that ``arr`` has at least
    as many rows as ``indices``.

    Args:
        arr (np.ndarray): 2D int32 array of data
        indices (np.ndarray): 2D uint16 array of column indices, one row per checked row of ``arr``

    Returns:
        bool: Whether every adjacent pair of every row is in order (True when ``indices``
        has fewer than two columns, as there is nothing to compare)
    """
    n_rows = indices.shape[0]
    n_cols = indices.shape[1]

    # With fewer than two columns there is no adjacent pair to compare: trivially sorted.
    if n_cols < 2:
        return True

    n = 0                           # adjacent pairs found in order
    total = n_rows * (n_cols - 1)   # adjacent pairs to check
    for i in range(n_rows):
        sub_indices = indices[i]
        for j in range(n_cols - 1):
            if arr[i, sub_indices[j]] <= arr[i, sub_indices[j + 1]]:
                n += 1

    if __DEBUG:
        if n != total:
            # Here n < total, so total > 0 and the ratio is well defined.
            print(n, total, n / total)

    return n == total
|