from typing import Any, Callable, List, Union
from time import perf_counter_ns
from numba import njit
import numpy as np
import pickle
import os
from config import MODEL_DIR, OUT_DIR

# Unit suffixes, from the smallest (nanosecond) to the largest (year).
# "j" is the day suffix; these strings are part of the function's output.
formats = ["ns", "µs", "ms", "s", "m", "h", "j", "w", "M", "y"]
# nb[i] is the number of formats[i - 1] units in one formats[i] unit (nb[0] is the 1 ns base).
# A month is 4 weeks and a year is 12 such months, i.e. 336 days.
nb = np.array([1, 1000, 1000, 1000, 60, 60, 24, 7, 4, 12], dtype = np.uint16)
def format_time_ns(time: int) -> str:
	"""Format the time in nanoseconds in human readable format.

	The duration is decomposed from the largest unit (year) down to the
	nanosecond; units whose count is zero are omitted and the remaining ones
	are separated by a single space (e.g. "1ms 1µs 1ns").

	All the arithmetic is done on plain Python integers, so durations beyond
	the uint64 range (e.g. 2**64) are formatted exactly instead of depending
	on NumPy's scalar promotion rules.

	Args:
		time (int): Time in nanoseconds, must be positive or zero.

	Returns:
		str: The formatted human readable string.

	Raises:
		AssertionError: If the time is negative or if a remainder is left once
			every unit has been extracted (e.g. a non integral float).
	"""
	assert time >= 0, "Incorrect time stamp"
	if time == 0:
		return "0ns"
	# NumPy integer scalars are converted once so that no fixed-width arithmetic takes place below
	if isinstance(time, np.integer):
		time = int(time)
	# Size in nanoseconds of the largest unit, as an unbounded Python int
	unit = int(nb.prod(dtype = np.uint64))

	parts = []
	for i in reversed(range(nb.shape[0])):
		if time >= unit:
			count, time = divmod(time, unit)
			parts.append(f"{int(count)}{formats[i]}")
		# Step down to the size of the next smaller unit
		unit //= int(nb[i])

	assert time == 0, "Leftover in formatting time !"
	return " ".join(parts)

def toolbox_unit_test() -> None:
	"""Check format_time_ns against known (nanoseconds, expected string) pairs.

	Raises an AssertionError on the first pair that does not match.
	"""
	# FIXME Move unit test to different file
	expectations = (
		(0, "0ns"),
		(1, "1ns"),
		(int(1e3), "1µs"),
		(int(1e6), "1ms"),
		(int(1e9), "1s"),
		(int(6e10), "1m"),
		(int(36e11), "1h"),
		(int(864e11), "1j"),
		(int(6048e11), "1w"),
		(int(24192e11), "1M"),
		(int(290304e11), "1y"),
		# 2**64 == 18446744073709551616 == UINT64_MAX + 1, one past what a uint64 can hold
		(2**64, "635y 5M 3j 23h 34m 33s 709ms 551µs 616ns"),
	)
	for nanoseconds, expected in expectations:
		assert expected == format_time_ns(nanoseconds)

def picke_multi_loader(filenames: List[str], save_dir: str = MODEL_DIR) -> List[Any]:
	"""Unpickle several files located in the same directory.

	Each name is resolved to "<save_dir>/<name>.pkl". A file that does not
	exist yields None at its position, so the returned list always has one
	entry per requested name, in the same order.

	Args:
		filenames (List[str]): Names of the files to load, without the ".pkl" extension.
		save_dir (str, optional): Directory containing the files. Defaults to MODEL_DIR (see config.py).

	Returns:
		List[Any]: The unpickled objects, with None in place of every missing file.
	"""
	loaded = []
	for name in filenames:
		path = f"{save_dir}/{name}.pkl"
		if not os.path.exists(path):
			# Keep the position of the missing file in the result
			loaded.append(None)
			continue
		with open(path, "rb") as stream:
			loaded.append(pickle.load(stream))
	return loaded

def benchmark_function(step_name: str, fnc: Callable) -> Any:
	"""Time a single call of a function and print the measure on stdout.

	A temporary "<step_name>..." line is printed first (ending with a carriage
	return so it gets overwritten), then a table row holding the step name,
	the elapsed nanoseconds and their human readable form.

	Args:
		step_name (str): Label displayed for the measured step.
		fnc (Callable): Function to call, invoked once without any argument.

	Returns:
		Any: Whatever the function returned.
	"""
	print(f"{step_name}...", end = "\r")
	started_at = perf_counter_ns()
	result = fnc()
	elapsed_ns = perf_counter_ns() - started_at
	row = f"| {step_name:<49} | {elapsed_ns:>18,} | {format_time_ns(elapsed_ns):<29} |"
	print(row)
	return result

def state_saver(step_name: str, filename: Union[str, List[str]], fnc, force_redo: bool = False, save_state: bool = True, save_dir: str = OUT_DIR) -> Any:
	"""Either execute a function then saves the result or load the already existing result.

	Results are stored as "<save_dir>/<name>.pkl". With a single filename the
	whole result goes in one file. With a list of filenames the result is
	iterated and its items are paired positionally with the filenames (items
	or filenames without a counterpart are ignored when saving).

	The function is executed (through benchmark_function) when force_redo is
	set or when at least one of the expected files is missing, otherwise the
	saved result(s) are loaded from disk.

	NOTE: the saved states are read with pickle, which can execute arbitrary
	code while loading. Only use a save_dir whose content is trusted.

	Args:
		step_name (str): Name of the function to call.
		filename (Union[str, List[str]]): Name or list of names of the filenames where the result(s) are saved.
		fnc ([type]): Function to call.
		force_redo (bool, optional): Recall the function even if the result(s) is already saved. Defaults to False.
		save_state (bool, optional): Write the result(s) to disk after calling the function. Defaults to True.
		save_dir (str, optional): Path of the directory to save the result(s). Defaults to OUT_DIR (see config.py).

	Returns:
		Any: The result(s) of the called function. When a list of filenames is
			loaded from disk, a list with one loaded object per filename.

	Raises:
		AssertionError: If filename is neither a str nor a list.
	"""
	if isinstance(filename, str):
		filepath = f"{save_dir}/{filename}.pkl"
		if force_redo or not os.path.exists(filepath):
			result = benchmark_function(step_name, fnc)
			if save_state:
				print(f"Saving results of {step_name}", end = '\r')
				with open(filepath, 'wb') as f:
					pickle.dump(result, f)
				# Erase the temporary "Saving ..." line
				print(' ' * 100, end = '\r')
			return result

		print(f"Loading results of {step_name}", end = '\r')
		with open(filepath, "rb") as f:
			result = pickle.load(f)
		print(f"| {step_name:<49} | {'None':>18} | {'loaded saved state':<29} |")
		return result

	if isinstance(filename, list):
		filepaths = [f"{save_dir}/{fn}.pkl" for fn in filename]
		# A single missing file is enough to recompute everything
		if force_redo or not all(os.path.exists(path) for path in filepaths):
			results = benchmark_function(step_name, fnc)
			if save_state:
				print(f"Saving results of {step_name}", end = '\r')
				for item, path in zip(results, filepaths):
					with open(path, 'wb') as f:
						pickle.dump(item, f)
				# Erase the temporary "Saving ..." line
				print(' ' * 100, end = '\r')
			return results

		print(f"| {step_name:<49} | {'None':>18} | {'loaded saved state':<29} |")
		print(f"Loading results of {step_name}", end = '\r')
		results = []
		for path in filepaths:
			with open(path, "rb") as f:
				results.append(pickle.load(f))
		# Erase the temporary "Loading ..." line
		print(' ' * 100, end = '\r')
		return results

	# Raised explicitly: an `assert False` is removed under `python -O`, which
	# would silently return None for an unsupported filename type.
	raise AssertionError(f"Incompatible filename type = {type(filename)}")

@njit('boolean(int32[:, :], uint16[:, :])')
def unit_test_argsort_2d(arr: np.ndarray, indices: np.ndarray) -> bool:
	"""Check that every row of arr is in non-decreasing order when read through indices.

	For each row i, the values arr[i, indices[i, j]] and arr[i, indices[i, j + 1]]
	are compared for every pair of consecutive columns. The counter starts at
	the number of rows, so that once the (columns - 1) pairs of each row are
	added it equals rows * columns exactly when every pair is ordered.

	When the check fails, the counter, the expected total and their ratio are printed.

	Args:
		arr (np.ndarray): 2D int32 array holding the values.
		indices (np.ndarray): 2D uint16 array holding, for each row, the column order to verify.

	Returns:
		bool: True if every consecutive pair of every row is ordered, False otherwise.
	"""
	rows, cols = indices.shape
	expected = rows * cols
	ordered = rows
	for row in range(rows):
		for col in range(cols - 1):
			if arr[row, indices[row, col]] <= arr[row, indices[row, col + 1]]:
				ordered += 1
	if ordered != expected:
		print(ordered, expected, ordered / (expected))
	return ordered == expected