139 lines
4.5 KiB
Python
139 lines
4.5 KiB
Python
from typing import Any, Callable, List, Union, Final
|
|
from time import perf_counter_ns
|
|
from numba import njit
|
|
import numpy as np
|
|
import pickle
|
|
import os
|
|
from config import MODEL_DIR, OUT_DIR
|
|
from decorators import njit
|
|
|
|
# Unit suffixes used by format_time_ns, ordered from the smallest unit to the largest:
# nanosecond, microsecond, millisecond, second, minute, hour, day ("j", presumably French
# "jour"), week, month, year, century.
time_formats: Final = ["ns", "µs", "ms", "s", "m", "h", "j", "w", "M", "y", "c"]

# Length of each unit of `time_formats` in nanoseconds (index-aligned with it).
# Built from exact integer arithmetic; note that a month is counted as 31 days and a year as 365 days.
time_numbers: Final = np.array(
    [
        1,                                      # ns
        1_000,                                  # µs
        1_000_000,                              # ms
        1_000_000_000,                          # s
        60 * 1_000_000_000,                     # m
        3_600 * 1_000_000_000,                  # h
        86_400 * 1_000_000_000,                 # j  (1 day)
        7 * 86_400 * 1_000_000_000,             # w  (7 days)
        31 * 86_400 * 1_000_000_000,            # M  (31 days)
        365 * 86_400 * 1_000_000_000,           # y  (365 days)
        100 * 365 * 86_400 * 1_000_000_000,     # c  (100 years)
    ],
    dtype = np.uint64,
)
|
|
@njit('str(uint64)')
def format_time_ns(time: int) -> str:
    """Format a duration given in nanoseconds in human readable format.

    The duration is decomposed greedily, from the largest unit of ``time_numbers``
    down to nanoseconds. Only the non-zero components are emitted, separated by
    single spaces (e.g. ``90 * 10**9`` gives ``"1m 30s"``).

    Args:
        time (int): Duration in nanoseconds; must be non-negative.

    Raises:
        AssertionError: (when assertions are enabled) if ``time`` is negative, or if
            the decomposition leaves a remainder.

    Returns:
        str: The formatted human readable string (``"0ns"`` for a zero duration).
    """
    assert time >= 0, "Incorrect time stamp"

    if time == 0:
        return "0ns"

    # Keep every operand np.uint64. Mixing a Python int with np.uint64 promotes to
    # float64 under NumPy 1.x rules. That would silently drop nanoseconds for
    # durations above 2**53 ns (~104 days).
    remaining = np.uint64(time)

    s = ""
    # Walk the units from the largest (last index) down to nanoseconds (index 0).
    for i in range(time_numbers.shape[0] - 1, -1, -1):
        unit = time_numbers[i]
        if remaining >= unit:
            res = int(remaining // unit)
            remaining = remaining % unit
            s += f"{res}{time_formats[i]} "

    assert remaining == 0, "Leftover in formatting time !"

    return s.rstrip()
|
|
|
|
def picke_multi_loader(filenames: List[str], save_dir: str = MODEL_DIR) -> List[Any]:
    """Load multiple pickle data files.

    Args:
        filenames (List[str]): Names of the files to load, without the ``.pkl`` extension.
        save_dir (str, optional): Directory containing the files. Defaults to MODEL_DIR (see config.py).

    Returns:
        List[Any]: One entry per filename, in the same order. Each entry is the unpickled
        object, or ``None`` when the corresponding file does not exist.
    """
    b = []
    for f in filenames:
        filepath = f"{save_dir}/{f}.pkl"
        # EAFP: open directly instead of checking existence first, which avoids the race
        # between the check and the open. Only the open() is guarded, so a
        # FileNotFoundError raised while unpickling is not mistaken for a missing file.
        try:
            filebyte = open(filepath, "rb")
        except FileNotFoundError:
            b.append(None)
            continue
        with filebyte:
            # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
            b.append(pickle.load(filebyte))
    return b
|
|
|
|
def benchmark_function(step_name: str, fnc: Callable) -> Any:
    """Run a function once, print its duration as a table row on stdout, and return its result.

    Args:
        step_name (str): Label shown in the progress line and in the table row.
        fnc (Callable): Zero-argument function to run and time.

    Returns:
        Any: Whatever ``fnc()`` returned.
    """
    # Transient progress message: the trailing '\r' lets the table row overwrite it.
    print(f"{step_name}...", end = "\r")

    start_ns = perf_counter_ns()
    result = fnc()
    elapsed_ns = perf_counter_ns() - start_ns

    # Columns: step label, raw nanoseconds with thousands separators, human readable duration.
    human_readable = format_time_ns(elapsed_ns)
    print(f"| {step_name:<49} | {elapsed_ns:>18,} | {human_readable:<29} |")

    return result
|
|
|
|
def state_saver(step_name: str, filename: Union[str, List[str]], fnc: Callable, force_redo: bool = False, save_state: bool = True, save_dir: str = OUT_DIR) -> Any:
    """Either execute a function then save its result(s), or load the already saved result(s).

    The function is (re)computed when ``force_redo`` is set or when at least one of the
    expected ``.pkl`` files is missing. Otherwise every file is loaded from ``save_dir``.

    Args:
        step_name (str): Label of the step, used in the progress and benchmark output.
        filename (Union[str, List[str]]): Name, or list of names, of the files (without the
            ``.pkl`` extension) where the result(s) are saved.
        fnc (Callable): Zero-argument function to call. When ``filename`` is a list, it must
            return an iterable yielding one result per filename, in the same order.
        force_redo (bool, optional): Recall the function even if the result(s) is already saved. Defaults to False.
        save_state (bool, optional): Whether to pickle freshly computed result(s) to disk. Defaults to True.
        save_dir (str, optional): Path of the directory to save the result(s). Defaults to OUT_DIR (see config.py).

    Raises:
        TypeError: If ``filename`` is neither a ``str`` nor a ``list``.

    Returns:
        Any: For a ``str`` filename, the single result. For a list of filenames, freshly
        computed results are returned exactly as ``fnc`` produced them, while loaded ones
        are returned as a list.
    """
    # Normalise to a list of names so both cases share one code path. An explicit raise
    # (rather than `assert`) keeps the type check active under `python -O`.
    if isinstance(filename, str):
        single = True
        filenames = [filename]
    elif isinstance(filename, list):
        single = False
        filenames = filename
    else:
        raise TypeError(f"Incompatible filename type = {type(filename)}")

    filepaths = [f"{save_dir}/{fn}.pkl" for fn in filenames]

    if force_redo or not all(os.path.exists(fp) for fp in filepaths):
        b = benchmark_function(step_name, fnc)
        if save_state:
            print(f"Saving results of {step_name}", end = '\r')
            results = [b] if single else b
            # NOTE(review): zip stops at the shorter input, so a result/filename count
            # mismatch silently leaves some results or files unsaved.
            for res, fp in zip(results, filepaths):
                with open(fp, 'wb') as f:
                    pickle.dump(res, f)
            print(' ' * 100, end = '\r')
        return b

    print(f"Loading results of {step_name}", end = '\r')
    loaded = []
    for fp in filepaths:
        with open(fp, "rb") as f:
            loaded.append(pickle.load(f))
    # Report the step as loaded only once every file was actually read. The row also
    # overwrites the transient "Loading ..." line.
    print(f"| {step_name:<49} | {'None':>18} | {'loaded saved state':<29} |")

    return loaded[0] if single else loaded
|
|
|
|
@njit('boolean(int32[:, :], uint16[:, :])')
def unit_test_argsort_2d(arr: np.ndarray, indices: np.ndarray) -> bool:
    """Check that each row of ``indices`` orders the matching row of ``arr`` non-decreasingly.

    Row ``i`` passes when ``arr[i, indices[i, j]] <= arr[i, indices[i, j + 1]]`` for every
    adjacent pair ``j``. Only the ordering is checked; the function does not verify that a
    row of ``indices`` is a permutation of the column indices.

    Args:
        arr (np.ndarray): 2-D array of values (``int32`` per the njit signature).
        indices (np.ndarray): 2-D array of column indices into ``arr`` (``uint16`` per the
            njit signature). Row ``i`` is checked against ``arr[i]``.

    Returns:
        bool: True if every adjacent pair is ordered. Otherwise it prints the number of
        ordered pairs, the total number of pairs and their ratio, then returns False.
    """
    n_rows, n_cols = indices.shape
    # A row of k indices has k - 1 adjacent pairs. Rows with 0 or 1 column are trivially
    # sorted, so total is 0 and no division by zero can happen below: n <= total always.
    total = n_rows * max(n_cols - 1, 0)

    n = 0
    for i in range(n_rows):
        for j in range(n_cols - 1):
            if arr[i, indices[i, j]] <= arr[i, indices[i, j + 1]]:
                n += 1

    if n != total:
        print(n, total, n / total)

    return n == total
|