#!/usr/bin/env python
# Author: @saundersp

from ViolaJones import train_viola_jones, classify_viola_jones
#from toolbox import state_saver, pickle_multi_loader, format_time_ns, benchmark_function, unit_test_argsort_2d
from toolbox import state_saver, format_time_ns, benchmark_function, unit_test_argsort_2d
from toolbox import header, footer, formatted_row, formatted_line
from toolbox_unit_test import format_time_test, format_time_ns_test
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix
#from sklearn.feature_selection import SelectPercentile, f_classif
from common import load_datasets, unit_test
from ViolaJones import build_features  # , get_best_anova_features
from typing import Tuple, List
from time import perf_counter_ns
from os import makedirs
import numpy as np

from config import FORCE_REDO, COMPILE_WITH_C, GPU_BOOSTED, TS, SAVE_STATE, MODEL_DIR, __DEBUG
if __DEBUG:
	from config import IDX_INSPECT, IDX_INSPECT_OFFSET

if GPU_BOOSTED:
	from ViolaJonesGPU import apply_features, set_integral_image, argsort_2d
	label = 'GPU' if COMPILE_WITH_C else 'PGPU'
	# The parallel prefix sum doesn't use the whole GPU, so Numba emits some annoying low-occupancy warnings; this disables them
	from numba import config
	config.CUDA_LOW_OCCUPANCY_WARNINGS = 0
else:
	from ViolaJonesCPU import apply_features, set_integral_image, argsort_2d
	label = 'CPU' if COMPILE_WITH_C else 'PY'

def preprocessing() -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
	"""Execute the preprocessing phase

	The preprocessing phase consists of the following steps:
	- Load the dataset
	- Calculate features
	- Calculate integral images (an illustrative sketch follows this function)
	- Apply features to images
	- Calculate argsort of the featured images

	Returns:
		Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]: Tuple containing, in order: training features, sorted indices of the training features, training labels, testing features, testing labels
	"""
	# Creating state saver folders if they don't exist already
	if SAVE_STATE:
		for folder_name in ['models', 'out']:
			makedirs(folder_name, exist_ok = True)

	preproc_timestamp = perf_counter_ns()
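	# Table layout values (column gaps) passed to the toolbox printing helpers (header, formatted_row, formatted_line, footer)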
	preproc_gaps = [49, -18, 29]
	header(preproc_gaps, ['Preprocessing', 'Time spent (ns)', 'Formatted time spent'])

	X_train, y_train, X_test, y_test = state_saver('Loading sets', preproc_gaps[0], ['X_train', 'y_train', 'X_test', 'y_test'],
													load_datasets, FORCE_REDO, SAVE_STATE)

	if __DEBUG:
		print('X_train')
		print(X_train.shape)
		print(X_train[IDX_INSPECT])
		print('X_test')
		print(X_test.shape)
		print(X_test[IDX_INSPECT])
		print('y_train')
		print(y_train.shape)
		print(y_train[IDX_INSPECT: IDX_INSPECT + IDX_INSPECT_OFFSET])
		print('y_test')
		print(y_test.shape)
		print(y_test[IDX_INSPECT: IDX_INSPECT + IDX_INSPECT_OFFSET])

	feats = state_saver('Building features', preproc_gaps[0], 'feats', lambda: build_features(X_train.shape[1], X_train.shape[2]),
						FORCE_REDO, SAVE_STATE)

	if __DEBUG:
		print('feats')
		print(feats.shape)
		print(feats[IDX_INSPECT].ravel())

	X_train_ii = state_saver(f'Converting training set to integral images ({label})', preproc_gaps[0], f'X_train_ii_{label}',
							lambda: set_integral_image(X_train), FORCE_REDO, SAVE_STATE)
	X_test_ii = state_saver(f'Converting testing set to integral images ({label})', preproc_gaps[0], f'X_test_ii_{label}',
							lambda: set_integral_image(X_test), FORCE_REDO, SAVE_STATE)

	if __DEBUG:
		print('X_train_ii')
		print(X_train_ii.shape)
		print(X_train_ii[IDX_INSPECT])
		print('X_test_ii')
		print(X_test_ii.shape)
		print(X_test_ii[IDX_INSPECT])

	X_train_feat = state_saver(f'Applying features to training set ({label})', preproc_gaps[0], f'X_train_feat_{label}',
							lambda: apply_features(feats, X_train_ii), FORCE_REDO, SAVE_STATE)
	X_test_feat = state_saver(f'Applying features to testing set ({label})', preproc_gaps[0], f'X_test_feat_{label}',
							lambda: apply_features(feats, X_test_ii), FORCE_REDO, SAVE_STATE)
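	# The integral images and raw feature descriptors are no longer needed once the feature matrices are built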
	del X_train_ii, X_test_ii, feats

	if __DEBUG:
		print('X_train_feat')
		print(X_train_feat.shape)
		print(X_train_feat[IDX_INSPECT, : IDX_INSPECT_OFFSET])
		print('X_test_feat')
		print(X_test_feat.shape)
		print(X_test_feat[IDX_INSPECT, : IDX_INSPECT_OFFSET])

	#indices = state_saver('Selecting best features training set', 'indices', force_redo = FORCE_REDO, save_state = SAVE_STATE,
	#						fnc = lambda: SelectPercentile(f_classif, percentile = 10).fit(X_train_feat.T, y_train).get_support(indices = True))
	#indices = state_saver('Selecting best features training set', 'indices', force_redo = FORCE_REDO, save_state = SAVE_STATE,
	#						fnc = lambda: get_best_anova_features(X_train_feat, y_train))
	#indices = benchmark_function('Selecting best features (manual)', lambda: get_best_anova_features(X_train_feat, y_train))

	#if __DEBUG:
	#	print('indices')
	#	print(indices.shape)
	#	print(indices[IDX_INSPECT: IDX_INSPECT + IDX_INSPECT_OFFSET])
	#	assert indices.shape[0] == indices_new.shape[0], f'Indices length not equal : {indices.shape} != {indices_new.shape}'
	#	assert (eq := indices == indices_new).all(), f'Indices not equal : {eq.sum() / indices.shape[0]}'

	# X_train_feat, X_test_feat = X_train_feat[indices], X_test_feat[indices]

	X_train_feat_argsort = state_saver(f'Precalculating training set argsort ({label})', preproc_gaps[0], f'X_train_feat_argsort_{label}',
									lambda: argsort_2d(X_train_feat), FORCE_REDO, SAVE_STATE)

	if __DEBUG:
		print('X_train_feat_argsort')
		print(X_train_feat_argsort.shape)
		print(X_train_feat_argsort[IDX_INSPECT, : IDX_INSPECT_OFFSET])
		benchmark_function('Arg unit test', preproc_gaps[0], lambda: unit_test_argsort_2d(X_train_feat, X_train_feat_argsort))

	X_test_feat_argsort = state_saver(f'Precalculating testing set argsort ({label})', preproc_gaps[0], f'X_test_feat_argsort_{label}',
									lambda: argsort_2d(X_test_feat), FORCE_REDO, SAVE_STATE)

	if __DEBUG:
		print('X_test_feat_argsort')
		print(X_test_feat_argsort.shape)
		print(X_test_feat_argsort[IDX_INSPECT, : IDX_INSPECT_OFFSET])
		benchmark_function('Arg unit test', preproc_gaps[0], lambda: unit_test_argsort_2d(X_test_feat, X_test_feat_argsort))

	time_spent = perf_counter_ns() - preproc_timestamp
	formatted_line(preproc_gaps, '├', '┼', '─', '┤')
	formatted_row(preproc_gaps, ['Preprocessing summary', f'{time_spent:,}', format_time_ns(time_spent)])
	footer(preproc_gaps)

	return X_train_feat, X_train_feat_argsort, y_train, X_test_feat, y_test
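
# Illustrative sketch only (not used by the pipeline): what the integral image step and the
# rectangle sums behind apply_features conceptually compute. The actual implementations live in
# ViolaJonesCPU / ViolaJonesGPU and may differ (zero padding, dtype, memory layout); this plain
# NumPy version is an assumption kept here for readability.
def _integral_image_sketch(X: np.ndarray) -> np.ndarray:
	"""Cumulative sums over rows then columns of a batch of images shaped (N, H, W),
	so that ii[n, y, x] is the sum of every pixel above and to the left of (y, x), inclusive."""
	return np.cumsum(np.cumsum(X, axis = 1), axis = 2)

def _rect_sum_sketch(ii: np.ndarray, x: int, y: int, w: int, h: int) -> np.ndarray:
	"""Sum of the (w, h) rectangle whose top-left corner is (x, y), computed in O(1) per image
	from integral images ii shaped (N, H, W) via inclusion-exclusion."""
	A = ii[:, y - 1, x - 1] if x > 0 and y > 0 else 0
	B = ii[:, y - 1, x + w - 1] if y > 0 else 0
	C = ii[:, y + h - 1, x - 1] if x > 0 else 0
	D = ii[:, y + h - 1, x + w - 1]
	return D - B - C + A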

def train(X_train_feat: np.ndarray, X_train_feat_argsort: np.ndarray, y_train: np.ndarray) -> List[List[np.ndarray]]:
	"""Train the weak classifiers.

	Args:
		X_train_feat (np.ndarray): Training features
		X_train_feat_argsort (np.ndarray): Sorted indices of the training features
		y_train (np.ndarray): Training labels

	Returns:
		List[List[np.ndarray]]: List of trained models; each entry is the pair [alphas, final_classifiers] (see the classification sketch after this function)
	"""

	training_timestamp = perf_counter_ns()
	training_gaps = [26, -18, 29]
	header(training_gaps, ['Training', 'Time spent (ns)', 'Formatted time spent'])
	models = []

	for T in TS:
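		# One boosted model per number of weak classifiers T; state_saver caches or reloads the
		# resulting [alphas, final_classifiers] pair under a T- and label-specific name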
		alphas, final_classifiers = state_saver(f'ViolaJones T = {T:<4} ({label})', training_gaps[0],
			[f'alphas_{T}_{label}', f'final_classifiers_{T}_{label}'],
			lambda: train_viola_jones(T, X_train_feat, X_train_feat_argsort, y_train), FORCE_REDO, SAVE_STATE, MODEL_DIR)
		models.append([alphas, final_classifiers])

		if __DEBUG:
			print('alphas')
			print(alphas)
			print('final_classifiers')
			print(final_classifiers)

	time_spent = perf_counter_ns() - training_timestamp
	formatted_line(training_gaps, '├', '┼', '─', '┤')
	formatted_row(training_gaps, ['Training summary', f'{time_spent:,}', format_time_ns(time_spent)])
	footer(training_gaps)

	return models
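
# Illustrative sketch only: the models returned above are [alphas, final_classifiers] pairs, and
# classify_viola_jones (used below) is assumed to follow the usual Viola-Jones decision rule, i.e.
# a sample is positive when the alpha-weighted 0/1 votes of the weak classifiers reach half of the
# total weight. The actual implementation in ViolaJones may differ in its details.
def _strong_classify_sketch(alphas: np.ndarray, weak_votes: np.ndarray) -> np.ndarray:
	"""alphas has shape (T,), weak_votes has shape (T, n_samples) with 0/1 entries;
	returns 0/1 predictions for the n_samples samples."""
	return (alphas @ weak_votes >= 0.5 * alphas.sum()).astype(np.uint8)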

def testing_and_evaluating(models: List[List[np.ndarray]], X_train_feat: np.ndarray, y_train: np.ndarray, X_test_feat: np.ndarray, y_test: np.ndarray) -> None:
	"""Benchmark the trained classifiers on the training and testing sets.

	Args:
		models (List[List[np.ndarray]]): List of trained models, each as the pair [alphas, final_classifiers]
		X_train_feat (np.ndarray): Training features
		y_train (np.ndarray): Training labels
		X_test_feat (np.ndarray): Testing features
		y_test (np.ndarray): Testing labels
	"""

	testing_gaps = [26, -19, 24, -19, 24]
	header(testing_gaps, ['Testing', 'Time spent (ns) (E)', 'Formatted time spent (E)', 'Time spent (ns) (T)', 'Formatted time spent (T)'])

	performances = []
	total_train_timestamp = 0
	total_test_timestamp = 0
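	# (E) quantities are measured on the training set, (T) quantities on the testing set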
	for T, (alphas, final_classifiers) in zip(TS, models):
		s = perf_counter_ns()
		y_pred_train = classify_viola_jones(alphas, final_classifiers, X_train_feat)
		t_pred_train = perf_counter_ns() - s
		total_train_timestamp += t_pred_train
		e_acc = accuracy_score(y_train, y_pred_train)
		e_f1 = f1_score(y_train, y_pred_train)
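		# sklearn's confusion_matrix is laid out as [[TN, FP], [FN, TP]]; keep only the error counts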
		(_, e_FP), (e_FN, _) = confusion_matrix(y_train, y_pred_train)

		s = perf_counter_ns()
		y_pred_test = classify_viola_jones(alphas, final_classifiers, X_test_feat)
		t_pred_test = perf_counter_ns() - s
		total_test_timestamp += t_pred_test
		t_acc = accuracy_score(y_test, y_pred_test)
		t_f1 = f1_score(y_test, y_pred_test)
		(_, t_FP), (t_FN, _) = confusion_matrix(y_test, y_pred_test)
		performances.append((e_acc, e_f1, e_FN, e_FP, t_acc, t_f1, t_FN, t_FP))

		formatted_row(testing_gaps, [f"{'ViolaJones T = ' + str(T):<19} {'(' + label + ')':<6}", f'{t_pred_train:,}',
									format_time_ns(t_pred_train), f'{t_pred_test:,}', format_time_ns(t_pred_test)])

	formatted_line(testing_gaps, '├', '┼', '─', '┤')
	formatted_row(testing_gaps, ['Testing summary', f'{total_train_timestamp:,}', format_time_ns(total_train_timestamp), f'{total_test_timestamp:,}',
					format_time_ns(total_test_timestamp)])
	footer(testing_gaps)

	evaluating_gaps = [19, 7, 6, 6, 6, 7, 6, 6, 6]
	header(evaluating_gaps, ['Evaluating', 'ACC (E)', 'F1 (E)', 'FN (E)', 'FP (E)', 'ACC (T)', 'F1 (T)', 'FN (T)', 'FP (T)'])

	for T, (e_acc, e_f1, e_FN, e_FP, t_acc, t_f1, t_FN, t_FP) in zip(TS, performances):
		print(f'│ ViolaJones T = {T:<4} │ {e_acc:>7.2%} │ {e_f1:>6.2f} │ {e_FN:>6,} │ {e_FP:>6,}', end = ' │ ')
		print(f'{t_acc:>7.2%} │ {t_f1:>6.2f} │ {t_FN:>6,} │ {t_FP:>6,} │')

	footer(evaluating_gaps)

def main() -> None:
	unit_timestamp = perf_counter_ns()
	unit_gaps = [27, -18, 29]
	header(unit_gaps, ['Unit testing', 'Time spent (ns)', 'Formatted time spent'])
	benchmark_function('testing format_time', unit_gaps[0], format_time_test)
	benchmark_function('testing format_time_ns', unit_gaps[0], format_time_ns_test)
	time_spent = perf_counter_ns() - unit_timestamp
	formatted_line(unit_gaps, '├', '┼', '─', '┤')
	formatted_row(unit_gaps, ['Unit testing summary', f'{time_spent:,}', format_time_ns(time_spent)])
	footer(unit_gaps)

	X_train_feat, X_train_feat_argsort, y_train, X_test_feat, y_test = preprocessing()
	models = train(X_train_feat, X_train_feat_argsort, y_train)

	# X_train_feat, X_test_feat = pickle_multi_loader([f'X_train_feat_{label}', f'X_test_feat_{label}'], OUT_DIR)
	# indices = pickle_multi_loader(['indices'], OUT_DIR)[0]
	# X_train_feat, X_test_feat = X_train_feat[indices], X_test_feat[indices]

	testing_and_evaluating(models, X_train_feat, y_train, X_test_feat, y_test)
	unit_test(TS)

if __name__ == '__main__':
	main()