#!/usr/bin/env python
# Author: @saundersp

from ViolaJones import train_viola_jones, classify_viola_jones
from toolbox import state_saver, picke_multi_loader, format_time_ns, benchmark_function, toolbox_unit_test, unit_test_argsort_2d
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix
from sklearn.feature_selection import SelectPercentile, f_classif
from common import load_datasets, unit_test
from ViolaJones import build_features, get_best_anova_features
from typing import Tuple
from time import perf_counter_ns
from os import makedirs
import numpy as np

from config import FORCE_REDO, COMPILE_WITH_C, GPU_BOOSTED, TS, SAVE_STATE, MODEL_DIR, __DEBUG
if __DEBUG:
	from config import IDX_INSPECT, IDX_INSPECT_OFFSET

if GPU_BOOSTED:
	from ViolaJonesGPU import apply_features, set_integral_image, argsort
	label = 'GPU' if COMPILE_WITH_C else 'PGPU'
	# The parallel prefix sum doesn't use the whole GPU, so numba emits some annoying low-occupancy warnings; this disables them
	from numba import config
	config.CUDA_LOW_OCCUPANCY_WARNINGS = 0
else:
	from ViolaJonesCPU import apply_features, set_integral_image, argsort
	label = 'CPU' if COMPILE_WITH_C else 'PY'
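# label tags every saved state file and result table with the backend in use:
# GPU/PGPU when GPU_BOOSTED, CPU/PY otherwise, the first of each pair when COMPILE_WITH_C is set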

def bench_train(X_train: np.ndarray, X_test: np.ndarray, y_train: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
	"""Train the weak classifiers.

	Args:
		X_train (np.ndarray): Training images.
		X_test (np.ndarray): Testing Images.
		y_train (np.ndarray): Training labels.

	Returns:
		Tuple[np.ndarray, np.ndarray]: Training and testing features.
	"""
	feats = state_saver("Building features", "feats", lambda: build_features(X_train.shape[1], X_train.shape[2]), FORCE_REDO, SAVE_STATE)

	if __DEBUG:
		print("feats")
		print(feats.shape)
		print(feats[IDX_INSPECT].ravel())

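	# Integral images make every rectangle sum an O(1) lookup, which keeps feature evaluation cheap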
	X_train_ii = state_saver(f"Converting training set to integral images ({label})", f"X_train_ii_{label}",
							lambda: set_integral_image(X_train), FORCE_REDO, SAVE_STATE)
	X_test_ii = state_saver(f"Converting testing set to integral images ({label})", f"X_test_ii_{label}",
							lambda: set_integral_image(X_test), FORCE_REDO, SAVE_STATE)

	if __DEBUG:
		print("X_train_ii")
		print(X_train_ii.shape)
		print(X_train_ii[IDX_INSPECT])
		print("X_test_ii")
		print(X_test_ii.shape)
		print(X_test_ii[IDX_INSPECT])

	X_train_feat = state_saver(f"Applying features to training set ({label})", f"X_train_feat_{label}",
							lambda: apply_features(feats, X_train_ii), FORCE_REDO, SAVE_STATE)
	X_test_feat = state_saver(f"Applying features to testing set ({label})", f"X_test_feat_{label}",
							lambda: apply_features(feats, X_test_ii), FORCE_REDO, SAVE_STATE)
	del X_train_ii, X_test_ii, feats
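	# X_train_feat / X_test_feat hold one row per feature and one column per sample
	# (hence the .T in the commented-out SelectPercentile call below)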

	if __DEBUG:
		print("X_train_feat")
		print(X_train_feat.shape)
		print(X_train_feat[IDX_INSPECT, : IDX_INSPECT_OFFSET])
		print("X_test_feat")
		print(X_test_feat.shape)
		print(X_test_feat[IDX_INSPECT, : IDX_INSPECT_OFFSET])

	#indices = state_saver("Selecting best features training set", "indices", force_redo = True, save_state = SAVE_STATE,
	#						fnc = lambda: SelectPercentile(f_classif, percentile = 10).fit(X_train_feat.T, y_train).get_support(indices = True))
	#indices = state_saver("Selecting best features training set", "indices", force_redo = FORCE_REDO, save_state = SAVE_STATE,
	#						fnc = lambda: get_best_anova_features(X_train_feat, y_train))
	#indices = benchmark_function("Selecting best features (manual)", lambda: get_best_anova_features(X_train_feat, y_train))

	# FIXME Debug code
	# print("indices")
	# print(indices.shape)
	# print(indices[IDX_INSPECT: IDX_INSPECT + IDX_INSPECT_OFFSET])
	# assert indices.shape[0] == indices_new.shape[0], f"Indices length not equal : {indices.shape} != {indices_new.shape}"
	# assert (eq := indices == indices_new).all(), f"Indices not equal : {eq.sum() / indices.shape[0]}"
	# return 0, 0

	# X_train_feat, X_test_feat = X_train_feat[indices], X_test_feat[indices]

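	# Pre-sorting every feature's sample values once avoids re-sorting them in each
	# boosting round when searching for the best weak-classifier threshold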
	X_train_feat_argsort = state_saver(f"Precalculating training set argsort ({label})", f"X_train_feat_argsort_{label}",
									lambda: argsort(X_train_feat), FORCE_REDO, SAVE_STATE)

	if __DEBUG:
		print("X_train_feat_argsort")
		print(X_train_feat_argsort.shape)
		print(X_train_feat_argsort[IDX_INSPECT, : IDX_INSPECT_OFFSET])
		benchmark_function("Arg unit test", lambda: unit_test_argsort_2d(X_train_feat, X_train_feat_argsort))

	X_test_feat_argsort = state_saver(f"Precalculating testing set argsort ({label})", f"X_test_feat_argsort_{label}",
									lambda: argsort(X_test_feat), FORCE_REDO, SAVE_STATE)

	if __DEBUG:
		print("X_test_feat_argsort")
		print(X_test_feat_argsort.shape)
		print(X_test_feat_argsort[IDX_INSPECT, : IDX_INSPECT_OFFSET])
		benchmark_function("Arg unit test", lambda: unit_test_argsort_2d(X_test_feat, X_test_feat_argsort))
		del X_test_feat_argsort

	print(f"\n| {'Training':<49} | {'Time spent (ns)':<18} | {'Formatted time spent':<29} |\n|{'-'*51}|{'-'*20}|{'-'*31}|")

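	# T is the number of boosting rounds, i.e. the number of weak classifiers AdaBoost combines;
	# one model is trained and saved per value of TS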
	for T in TS:
		alphas, final_classifiers = state_saver(f"ViolaJones T = {T:<3} ({label})", [f"alphas_{T}_{label}", f"final_classifiers_{T}_{label}"],
					lambda: train_viola_jones(T, X_train_feat, X_train_feat_argsort, y_train), FORCE_REDO, SAVE_STATE, MODEL_DIR)
		if __DEBUG:
			print("alphas")
			print(alphas)
			print("final_classifiers")
			print(final_classifiers)

	return X_train_feat, X_test_feat

def bench_accuracy(label: str, X_train_feat: np.ndarray, X_test_feat: np.ndarray, y_train: np.ndarray, y_test: np.ndarray) -> None:
	"""Benchmark the trained classifiers on the training and testing sets.

	Args:
		label (str): Backend label used to load the saved models.
		X_train_feat (np.ndarray): Training features.
		X_test_feat (np.ndarray): Testing features.
		y_train (np.ndarray): Training labels.
		y_test (np.ndarray): Testing labels.
	"""
	print(f"\n| {'Testing':<26} | Time spent (ns) (E) | {'Formatted time spent (E)':<29}", end = " | ")
	print(f"Time spent (ns) (T) | {'Formatted time spent (T)':<29} |")
	print(f"|{'-'*28}|{'-'*21}|{'-'*31}|{'-'*21}|{'-'*31}|")

	perfs = []
	for T in TS:
		(alphas, final_classifiers) = picke_multi_loader([f"alphas_{T}_{label}", f"final_classifiers_{T}_{label}"])
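		# alphas holds the AdaBoost weights of the weak classifiers stored in final_classifiers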

		s = perf_counter_ns()
		y_pred_train = classify_viola_jones(alphas, final_classifiers, X_train_feat)
		t_pred_train = perf_counter_ns() - s
		e_acc = accuracy_score(y_train, y_pred_train)
		e_f1 = f1_score(y_train, y_pred_train)
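		# sklearn's binary confusion_matrix is laid out [[TN, FP], [FN, TP]], hence this unpacking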
		(_, e_FP), (e_FN, _) = confusion_matrix(y_train, y_pred_train)

		s = perf_counter_ns()
		y_pred_test = classify_viola_jones(alphas, final_classifiers, X_test_feat)
		t_pred_test = perf_counter_ns() - s
		t_acc = accuracy_score(y_test, y_pred_test)
		t_f1 = f1_score(y_test, y_pred_test)
		(_, t_FP), (t_FN, _) = confusion_matrix(y_test, y_pred_test)
		perfs.append((e_acc, e_f1, e_FN, e_FP, t_acc, t_f1, t_FN, t_FP))

		print(f"| {'ViolaJones T = ' + str(T):<19} {'(' + label + ')':<6}", end = " | ")
		print(f"{t_pred_train:>19,} | {format_time_ns(t_pred_train):<29}", end = " | ")
		print(f"{t_pred_test:>19,} | {format_time_ns(t_pred_test):<29} |")

	print(f"\n| {'Evaluating':<19} | ACC (E) | F1 (E) | FN (E) | FP (E) | ACC (T) | F1 (T) | FN (T) | FP (T) | ")
	print(f"|{'-'*21}|{'-'*9}|{'-'*8}|{'-'*8}|{'-'*8}|{'-'*9}|{'-'*8}|{'-'*8}|{'-'*8}|")

	for T, (e_acc, e_f1, e_FN, e_FP, t_acc, t_f1, t_FN, t_FP) in zip(TS, perfs):
		print(f"| {'ViolaJones T = ' + str(T):<19} | {e_acc:>7.2%} | {e_f1:>6.2f} | {e_FN:>6,} | {e_FP:>6,}", end = " | ")
		print(f"{t_acc:>7.2%} | {t_f1:>6.2f} | {t_FN:>6,} | {t_FP:>6,} |")

def _main_() -> None:

	# Creating state saver folders if they don't exist already
	if SAVE_STATE:
		for folder_name in ["models", "out"]:
			makedirs(folder_name, exist_ok = True)

	print(f"| {'Preprocessing':<49} | {'Time spent (ns)':<18} | {'Formatted time spent':<29} |\n|{'-'*51}|{'-'*20}|{'-'*31}|")

	X_train, y_train, X_test, y_test = state_saver("Loading sets", ["X_train", "y_train", "X_test", "y_test"],
													load_datasets, FORCE_REDO, SAVE_STATE)

	if __DEBUG:
		print("X_train")
		print(X_train.shape)
		print(X_train[IDX_INSPECT])
		print("X_test")
		print(X_test.shape)
		print(X_test[IDX_INSPECT])
		print("y_train")
		print(y_train.shape)
		print(y_train[IDX_INSPECT: IDX_INSPECT + IDX_INSPECT_OFFSET])
		print("y_test")
		print(y_test.shape)
		print(y_test[IDX_INSPECT: IDX_INSPECT + IDX_INSPECT_OFFSET])

	X_train_feat, X_test_feat = bench_train(X_train, X_test, y_train)

	# X_train_feat, X_test_feat = picke_multi_loader([f"X_train_feat_{label}", f"X_test_feat_{label}"], OUT_DIR)
	# indices = picke_multi_loader(["indices"], OUT_DIR)[0]
	# X_train_feat, X_test_feat = X_train_feat[indices], X_test_feat[indices]

	bench_accuracy(label, X_train_feat, X_test_feat, y_train, y_test)

if __name__ == "__main__":
	_main_()
	if __DEBUG:
		toolbox_unit_test()

	# Only run the unit tests once the models for the listed labels have been trained
	unit_test(TS, ["GPU", "CPU", "PY", "PGPU"])