2024-04-28 22:35:42 +02:00

224 lines
7.0 KiB
Python

from typing import Any, Callable, List, Union, Final
from time import perf_counter_ns
import numpy as np
from sys import stderr
import pickle
import os
from config import MODEL_DIR, OUT_DIR, __DEBUG
from decorators import njit
def formatted_row(gaps: list[int], titles: list[str], separator: str = '') -> None:
    """Print one table row, each title padded to the width of its gap.

    A negative gap right-aligns its title and a non-negative gap left-aligns it;
    the column width is always ``abs(gap)``. Every cell is rendered as
    ``separator + ' ' + padded title + ' '`` and the row is closed by one more
    ``separator`` followed by a newline.

    Args:
        gaps: Column widths (the sign selects the alignment)
        titles: Cell contents, paired with ``gaps`` by position
        separator: String drawn before every cell and once at the end of the row
    """
    cells = []
    for width, text in zip(gaps, titles):
        align = '>' if width < 0 else '<'
        cells.append(f"{separator} {text:{align}{abs(width)}} ")
    print(''.join(cells) + separator)
def formatted_line(gaps: list[int], left: str, middle: str, separator: str, right: str) -> None:
    """Print a horizontal rule made of repeated characters.

    Each gap becomes a run of ``separator`` repeated ``abs(gap) + 2`` times (the
    +2 covers the one-space padding ``formatted_row`` puts on each side of a
    cell). Runs are joined by ``middle`` and framed by ``left`` and ``right``.

    Args:
        gaps: Column widths (only the magnitude is used)
        left: String printed before the first run
        middle: String printed between two consecutive runs
        separator: String repeated to fill each run
        right: String printed after the last run
    """
    runs = [separator * (abs(width) + 2) for width in gaps]
    print(left + middle.join(runs) + right)
def header(gaps: list[int], titles: list[str]) -> None:
    """Print a table header: a horizontal rule, the title row, then another rule.

    Args:
        gaps: Column widths forwarded to formatted_line / formatted_row
            (a negative width right-aligns its title)
        titles: Column titles, one per gap
    """
    # NOTE(review): every border string is empty, so both rules print a bare
    # newline and the title row has no separators. These may once have been
    # box-drawing characters — confirm the intended layout.
    rule_chars = ('', '', '', '')
    formatted_line(gaps, *rule_chars)
    formatted_row(gaps, titles)
    formatted_line(gaps, *rule_chars)
def footer(gaps: list[int]) -> None:
    """Print the closing horizontal rule of a table.

    Args:
        gaps: Column widths (presumably the same list given to header — verify at call sites)
    """
    # NOTE(review): with every border string empty the rule reduces to a bare
    # newline; the empty strings may be stripped box-drawing characters — confirm.
    left = middle = fill = right = ''
    formatted_line(gaps, left, middle, fill, right)
# Unit suffixes, smallest to largest; 'j' is the day unit and 'M' a 31-day month
# (see the matching lengths in time_numbers below).
time_formats: Final = ['ns', 'µs', 'ms', 's', 'm', 'h', 'j', 'w', 'M', 'y', 'c']
# Length of each unit of time_formats in nanoseconds, in the same order.
time_numbers: Final = np.array([
    1,                                          # ns
    1_000,                                      # µs
    1_000_000,                                  # ms
    1_000_000_000,                              # s
    60 * 1_000_000_000,                         # m (minute)
    60 * 60 * 1_000_000_000,                    # h
    24 * 60 * 60 * 1_000_000_000,               # j (day)
    7 * 24 * 60 * 60 * 1_000_000_000,           # w
    31 * 24 * 60 * 60 * 1_000_000_000,          # M (31-day month)
    365 * 24 * 60 * 60 * 1_000_000_000,         # y (365-day year)
    100 * 365 * 24 * 60 * 60 * 1_000_000_000,   # c (100 years)
], dtype = np.uint64)
@njit('str(uint64)')
def format_time_ns(time: int) -> str:
    """Render a duration given in nanoseconds as a human readable string.

    The duration is split greedily over every unit of ``time_formats``, from the
    largest ('c') down to the smallest ('ns'); only non-zero components are
    emitted, separated by single spaces.

    Args:
        time (int): Duration in nanoseconds, must be non-negative

    Returns:
        str: The formatted string, e.g. '1s 500ms' for 1_500_000_000
    """
    assert time >= 0, 'Incorrect time stamp'
    if time == 0:
        return '0ns'
    text = ''
    idx = time_numbers.shape[0] - 1
    # Walk the unit table from the largest unit down to nanoseconds.
    while idx >= 0:
        unit = time_numbers[idx]
        if time >= unit:
            count = int(time // unit)
            time = time % unit
            text += f'{count}{time_formats[idx]} '
        idx -= 1
    assert time == 0, 'Leftover in formatting time !'
    return text.rstrip()
@njit('str(uint64)')
def format_time(time: int) -> str:
    """Format a duration given in whole seconds in human readable format.

    The duration is split greedily over the units of ``time_formats`` from the
    largest down to seconds ('s'); only non-zero components are kept, separated
    by single spaces.

    Args:
        time (int): Duration in seconds, must be non-negative

    Returns:
        str: The formatted human readable string, e.g. '1h 1m 1s' for 3661
    """
    assert time >= 0, 'Incorrect time stamp'
    if time == 0:
        return '0s'
    # time_numbers stores unit lengths in nanoseconds; index 3 is one second.
    one_second = time_numbers[3]
    s = ''
    for i in range(3, time_numbers.shape[0])[::-1]:
        # Unit length in seconds. Floor-dividing two uint64 values keeps this an
        # exact integer; true division (``/``) would make it a float and turn every
        # following ``//`` / ``%`` into inexact floating-point arithmetic.
        time_number = time_numbers[i] // one_second
        if time >= time_number:
            res = int(time // time_number)
            time = time % time_number
            s += f'{res}{time_formats[i]} '
    assert time == 0, 'Leftover in formatting time !'
    return s.rstrip()
def pickle_multi_loader(filenames: List[str], save_dir: str = MODEL_DIR) -> List[Any]:
    """Load several pickle files from one directory.

    Each name is resolved to ``<save_dir>/<name>.pkl``. A file that does not
    exist yields ``None`` in its slot, so the result always has the same length
    and order as ``filenames``.

    Args:
        filenames (List[str]): Base names (without the ``.pkl`` extension) to load
        save_dir (str, optional): Directory holding the files. Defaults to MODEL_DIR (see config.py)

    Returns:
        List[Any]: The unpickled objects, with ``None`` for every missing file
    """
    def load_one(name: str) -> Any:
        path = f'{save_dir}/{name}.pkl'
        if not os.path.exists(path):
            return None
        # NOTE: pickle.load can execute arbitrary code; only point this at trusted files.
        with open(path, 'rb') as stream:
            return pickle.load(stream)

    return [load_one(name) for name in filenames]
def benchmark_function(step_name: str, column_width: int, fnc: Callable) -> Any:
    """Benchmark a function and display the result in stdout.

    While ``fnc`` runs, ``'<step_name>...'`` is shown on stderr, ending in a
    carriage return so later output overwrites it. Afterwards one row is printed
    to stdout with three cells:
    - the step name, left-aligned to ``column_width``;
    - the elapsed nanoseconds, right-aligned to 18 characters with thousands separators;
    - the human readable duration, left-aligned to 29 characters.

    Each cell is padded by one space on both sides, the same layout
    ``formatted_row`` uses with an empty separator.

    Args:
        step_name (str): Label of the step, shown in the first column
        column_width (int): Width of the first column
        fnc (Callable): Zero-argument function to call

    Returns:
        Any: Whatever ``fnc`` returns
    """
    print(f'{step_name}...', file = stderr, end = '\r')
    start = perf_counter_ns()
    result = fnc()
    elapsed = perf_counter_ns() - start
    # Without padding the right-aligned count runs straight into the duration
    # text (e.g. '1,500,000,0001s 500ms'), so the cells are space-separated here.
    print(f' {step_name:<{column_width}}  {elapsed:>18,}  {format_time_ns(elapsed):<29} ')
    return result
def _pickle_path(save_dir: str, name: str) -> str:
    """Return the path of the pickle file ``name`` inside ``save_dir``."""
    return f'{save_dir}/{name}.pkl'


def _dump_pickle(obj: Any, filepath: str) -> None:
    """Serialize ``obj`` to ``filepath`` with pickle, overwriting any existing file."""
    with open(filepath, 'wb') as f:
        pickle.dump(obj, f)


def _load_pickle(filepath: str) -> Any:
    """Deserialize and return the object stored in ``filepath``."""
    # NOTE: pickle.load can execute arbitrary code; the files read here are
    # expected to be ones previously written by state_saver.
    with open(filepath, 'rb') as f:
        return pickle.load(f)


def _print_loaded_row(step_name: str, column_width: int) -> None:
    """Print the stdout row reported when a step's results were restored from disk."""
    # Cells are padded by one space on both sides (formatted_row's layout) so
    # they do not run together.
    print(f" {step_name:<{column_width}}  {'None':>18}  {'loaded saved state':<29} ")


def state_saver(step_name: str, column_width: int, filename: Union[str, List[str]], fnc: Callable, force_redo: bool = False,
                save_state: bool = True, save_dir: str = OUT_DIR) -> Any:
    """Either execute a function then save the result, or load the already existing result.

    Args:
        step_name (str): Name of the step, shown in progress messages and in the result row
        column_width (int): Width of the first column of the result row
        filename (Union[str, List[str]]): Name, or list of names, of the pickle files
            (without extension) holding the result(s). With a list, ``fnc`` must return
            an iterable with one result per name, saved pairwise.
        fnc (Callable): Zero-argument function producing the result(s)
        force_redo (bool, optional): Recall the function even if the result(s) are already saved. Defaults to False
        save_state (bool, optional): Whether to pickle freshly computed result(s). Defaults to True
        save_dir (str, optional): Directory of the saved result(s). Defaults to OUT_DIR (see config.py)

    Returns:
        Any: The result(s) of the called function, either freshly computed or loaded
        (a list of loaded objects when ``filename`` is a list)

    Raises:
        AssertionError: If ``filename`` is neither a ``str`` nor a ``list``
    """
    if isinstance(filename, str):
        filepath = _pickle_path(save_dir, filename)
        if force_redo or not os.path.exists(filepath):
            b = benchmark_function(step_name, column_width, fnc)
            if save_state:
                print(f'Saving results of {step_name}', file = stderr, end = '\r')
                _dump_pickle(b, filepath)
                print(' ' * 100, file = stderr, end = '\r')
            return b
        print(f'Loading results of {step_name}', file = stderr, end = '\r')
        res = _load_pickle(filepath)
        print(' ' * 100, file = stderr, end = '\r')
        _print_loaded_row(step_name, column_width)
        return res

    if isinstance(filename, list):
        filepaths = [_pickle_path(save_dir, fn) for fn in filename]
        # Recompute when forced or when any single result file is missing.
        if force_redo or not all(os.path.exists(fp) for fp in filepaths):
            b = benchmark_function(step_name, column_width, fnc)
            if save_state:
                print(f'Saving results of {step_name}', file = stderr, end = '\r')
                # zip stops at the shorter side: fnc is expected to return
                # exactly one result per filename.
                for bi, fp in zip(b, filepaths):
                    _dump_pickle(bi, fp)
                print(' ' * 100, file = stderr, end = '\r')
            return b
        print(f'Loading results of {step_name}', file = stderr, end = '\r')
        b = [_load_pickle(fp) for fp in filepaths]
        print(' ' * 100, file = stderr, end = '\r')
        _print_loaded_row(step_name, column_width)
        return b

    # Raised explicitly rather than via ``assert`` so the check survives ``python -O``;
    # AssertionError is kept for compatibility with the original ``assert False``.
    raise AssertionError(f'Incompatible filename type = {type(filename)}')
@njit('boolean(int32[:, :], uint16[:, :])')
def unit_test_argsort_2d(arr: np.ndarray, indices: np.ndarray) -> bool:
    """Check that every row of ``indices`` sorts the matching row of ``arr``.

    Row ``r`` passes when ``arr[r, indices[r, c]] <= arr[r, indices[r, c + 1]]``
    holds for every adjacent pair ``c``.

    Args:
        arr (np.ndarray): 2D array of data
        indices (np.ndarray): 2D array of column indices, one sorting permutation per row

    Returns:
        bool: Whether every adjacent pair in every row is in non-decreasing order
    """
    rows, cols = indices.shape
    total = rows * cols
    # A fully sorted input has rows * (cols - 1) ordered adjacent pairs; starting
    # the counter at ``rows`` makes that case reach exactly ``total``.
    ordered = rows
    for r in range(rows):
        for c in range(cols - 1):
            if arr[r, indices[r, c]] <= arr[r, indices[r, c + 1]]:
                ordered += 1
    if __DEBUG:
        if ordered != total:
            print(ordered, total, ordered / (total))
    return ordered == total