#pragma once
#include <filesystem>
#include <system_error>
#include "data.hpp"
#include "toolbox.hpp"

/**
 * @brief Test whether an array computed on the CPU is element-wise equal to its GPU counterpart.
 *
 * The shapes are compared first; then every element pair is compared with operator==.
 * When __DEBUG is enabled, a shape mismatch is reported on stderr, and mismatching elements
 * are reported on stdout together with the count and percentage of equal elements.
 *
 * @tparam T Inner type of the arrays to test
 * @param cpu CPU Array
 * @param gpu GPU Array
 * @return Whether the test was successful
 */
template <typename T>
bool unit_test_cpu_vs_gpu(const np::Array<T>& cpu, const np::Array<T>& gpu) noexcept {
	if (cpu.shape != gpu.shape) {
#if __DEBUG
		fprintf(stderr, "Unequal shape !\n");
#endif
		return false;
	}

	const size_t n_elements = np::prod(cpu.shape);
	size_t n_equal = 0;
	for (size_t idx = 0; idx != n_elements; ++idx) {
		if (!(cpu[idx] == gpu[idx]))
			continue;
		++n_equal;
	}

	const bool all_equal = (n_equal == n_elements);
#if __DEBUG
	if (!all_equal) {
		const float64_t ratio = static_cast<float64_t>(n_equal) / static_cast<float64_t>(n_elements) * 100.0;
		printf("Incorrect results, Number of equalities : %s/%s <=> %.2f%% !\n", thousand_sep(n_equal).c_str(), thousand_sep(n_elements).c_str(), ratio);
	}
#endif
	return all_equal;
}

/**
 * @brief Test if a given 2D array of indices sorts a given 2D array, row by row.
 *
 * For every row, each pair of consecutive indices must reference elements in non-decreasing
 * order. The first element of each row has no predecessor and always counts as correct.
 * An index that falls outside its row is counted as incorrect instead of being dereferenced,
 * so a corrupted result (e.g. from a GPU kernel) cannot trigger an out-of-bounds read.
 * An empty array (including rows of width zero) is trivially sorted.
 * Note: this does not verify that each row of indices is a permutation.
 *
 * @tparam T Inner type of the array to test
 * @param a 2D Array of data
 * @param indices 2D Indices that sort the array
 * @return Whether the test was successful
 */
template <typename T>
bool unit_test_argsort_2d(const np::Array<T>& a, const np::Array<uint16_t>& indices) noexcept {
	if (a.shape != indices.shape) {
#if __DEBUG
		fprintf(stderr, "Unequal shape !\n");
#endif
		return false;
	}
	const size_t total = np::prod(a.shape);
	if (total == 0) // Nothing to sort; also guarantees a non-zero row stride below
		return true;
	const size_t row_length = a.shape[1];
	size_t correct = a.shape[0]; // First elements are always correctly sorted
	for (size_t i = 0; i < total; i += row_length)
		for (size_t j = 0; j + 1 < row_length; ++j) { // j + 1 < n cannot underflow, unlike n - 1
			const size_t k = i + j;
			const size_t lhs = indices[k];
			const size_t rhs = indices[k + 1];
			// An out-of-range index is an incorrect result, never an element to read
			if (lhs < row_length && rhs < row_length && a[i + lhs] <= a[i + rhs])
				++correct;
		}
#if __DEBUG
	if (correct != total)
		printf("Incorrect results, Number of equalities : %s/%s <=> %.2f%% !\n", thousand_sep(correct).c_str(), thousand_sep(total).c_str(),
				static_cast<float64_t>(correct) / static_cast<float64_t>(total) * 100.0);
#endif
	return correct == total;
}

/**
 * @brief Benchmark a function and display the result in stdout.
 *
 * When __DEBUG is false, "<step_name>..." is first written to stderr (terminated by '\r') as a
 * progress hint. The call is timed with perf_counter_ns/duration_ns, and a row containing the
 * step name, the elapsed nanoseconds and a human-readable duration is printed via formatted_row.
 *
 * @tparam T Resulting type of the function to benchmark
 * @tparam F Signature of the function to call
 * @tparam Args Arguments variadic of the function to call
 * @param step_name Name of the function to log
 * @param column_width Width of the column to print during logging
 * @param fnc Function to benchmark
 * @param args Arguments to pass to the function to call
 * @return Result of the benchmarked function
 */
template <typename T, typename F, typename... Args>
T benchmark_function(const char* const step_name, const int32_t& column_width, const F& fnc, Args &&...args) noexcept {
#if __DEBUG == false
	fprintf(stderr, "%s...\r", step_name);
	fflush(stderr); // manual flush is mandatory, otherwise it will not be shown immediately because the output is buffered
#endif
	const std::chrono::system_clock::time_point start = perf_counter_ns();
	// Deliberately not const: a const local cannot be implicitly moved on return, so whenever
	// NRVO is not applied the (possibly large) result would be copied instead of moved.
	T res = fnc(std::forward<Args>(args)...);
	const long long time_spent = duration_ns(perf_counter_ns() - start);
	formatted_row<3>({ column_width, -18, 29 }, { step_name, thousand_sep(time_spent).c_str(), format_time_ns(time_spent).c_str() });
	return res;
}

/**
 * @brief Benchmark a function and display the result in stdout.
 *
 * @tparam F Signature of the function to call
 * @tparam Args Arguments variadic of the function to call
 * @param step_name Name of the function to log
 * @param column_width Width of the column to print during logging
 * @param fnc Function to benchmark
 * @param args Arguments to pass to the function to call
 */
template <typename F, typename... Args>
void benchmark_function_void(const char* const step_name, const int32_t& column_width, const F& fnc, Args &&...args) noexcept {
#if __DEBUG == false
	fprintf(stderr, "%s...\r", step_name);
	fflush(stderr); // manual flush is mandatory, otherwise it will not be shown immediately because the output is buffered
#endif
	const std::chrono::system_clock::time_point start = perf_counter_ns();
	fnc(std::forward<Args>(args)...);
	const long long time_spent = duration_ns(perf_counter_ns() - start);
	formatted_row<3>({ column_width, -18, 29 }, { step_name, thousand_sep(time_spent).c_str(), format_time_ns(time_spent).c_str() });
}

/**
 * @brief Either execute a function then save the result or load the already cached result.
 *
 * The cache file is "<out_dir>/<filename>.bin". If that file does not exist (or its existence
 * cannot be determined), or if force_redo is set, fnc is run through benchmark_function and its
 * result is optionally saved. Otherwise the result is loaded from disk and a
 * "loaded saved state" row is printed.
 *
 * @tparam T Inner type of the resulting array
 * @tparam F Signature of the function to call
 * @tparam Args Arguments variadic of the function to call
 * @param step_name Name of the function to log
 * @param column_width Width of the column to print during logging
 * @param filename Name of the filename where the result is saved
 * @param force_redo Recall the function even if the result is already saved, ignored if result is not cached
 * @param save_state Whether the computed result will be saved or not, ignored if loading already cached result
 * @param out_dir Path of the directory to save the result
 * @param fnc Function to call
 * @param args Arguments to pass to the function to call
 * @return The result of the called function
 */
template <typename T, typename F, typename... Args>
np::Array<T> state_saver(const char* const step_name, const int32_t& column_width, const char* const filename, const bool& force_redo, const bool& save_state, const char* const out_dir, const F& fnc, Args &&...args) noexcept {
	char filepath[BUFFER_SIZE] = { 0 };
	snprintf(filepath, BUFFER_SIZE, "%s/%s.bin", out_dir, filename);

	// This function is noexcept: the throwing std::filesystem::exists overload would call
	// std::terminate on a filesystem error (e.g. permission denied). Such errors make the
	// non-throwing overload return false, so the result is simply recomputed.
	std::error_code ec;
	const bool cached = std::filesystem::exists(filepath, ec);

	np::Array<T> bin;
	if (!cached || force_redo) {
		bin = benchmark_function<np::Array<T>>(step_name, column_width, fnc, std::forward<Args>(args)...);
		if (save_state) {
#if __DEBUG == false
			fprintf(stderr, "Saving results of %s\r", step_name);
			fflush(stderr);
#endif
			save<T>(bin, filepath);
#if __DEBUG == false
			fprintf(stderr, "%*c\r", 100, ' '); // Clear previous print
			fflush(stderr);
#endif
		}
	} else {
#if __DEBUG == false
		fprintf(stderr, "Loading results of %s\r", step_name);
		fflush(stderr);
#endif
		bin = load<T>(filepath);
		formatted_row<3>({ column_width, -18, 29 }, { step_name, "None", "loaded saved state" });
	}
	return bin;
}

/**
 * @brief Either execute a function then save the results or load the already cached results.
 *
 * filenames[i] names the cache file "<out_dir>/<filenames[i]>.bin" of the i-th result. Only the
 * first N filenames are used; extra names are ignored. With fewer than N filenames, some results
 * could not be loaded from disk, so the function is always recomputed in that case. The results
 * are recomputed as well when any used cache file is missing (or its existence cannot be
 * determined), or when force_redo is set.
 *
 * @tparam T Inner type of the resulting arrays
 * @tparam N Number of resulting arrays
 * @tparam F Signature of the function to call
 * @tparam Args Arguments variadic of the function to call
 * @param step_name Name of the function to log
 * @param column_width Width of the column to print during logging
 * @param filenames List of names of the filenames where the results are saved
 * @param force_redo Recall the function even if the results are already saved, ignored if results are not cached
 * @param save_state Whether the computed results will be saved or not, ignored if loading already cached results
 * @param out_dir Path of the directory to save the results
 * @param fnc Function to call
 * @param args Arguments to pass to the function to call
 * @return The results of the called function
 */
template <typename T, size_t N, typename F, typename... Args>
std::array<np::Array<T>, N> state_saver(const char* const step_name, const int32_t& column_width, const std::vector<const char*>& filenames, const bool& force_redo, const bool& save_state, const char* const out_dir, const F& fnc, Args &&...args) noexcept {
	// Never index past either filenames or the returned std::array
	const size_t count = filenames.size() < N ? filenames.size() : N;
	bool abs = count < N; // Missing filenames mean some results can never be loaded

	char filepath[BUFFER_SIZE] = { 0 };
	std::error_code ec;
	for (size_t i = 0; i < count && !abs; ++i) {
		snprintf(filepath, BUFFER_SIZE, "%s/%s.bin", out_dir, filenames[i]);
		// Non-throwing overload: a filesystem_error inside this noexcept function would call
		// std::terminate; an error makes exists() return false, i.e. the cache is treated as absent.
		abs = !std::filesystem::exists(filepath, ec);
	}

	std::array<np::Array<T>, N> bin;
	if (abs || force_redo) {
		bin = benchmark_function<std::array<np::Array<T>, N>>(step_name, column_width, fnc, std::forward<Args>(args)...);
		if (save_state) {
#if __DEBUG == false
			fprintf(stderr, "Saving results of %s\r", step_name);
			fflush(stderr);
#endif
			for (size_t i = 0; i < count; ++i) {
				snprintf(filepath, BUFFER_SIZE, "%s/%s.bin", out_dir, filenames[i]);
				save<T>(bin[i], filepath);
			}
#if __DEBUG == false
			fprintf(stderr, "%*c\r", 100, ' '); // Clear previous print
			fflush(stderr);
#endif
		}
	} else {
#if __DEBUG == false
		fprintf(stderr, "Loading results of %s\r", step_name);
		fflush(stderr);
#endif
		for (size_t i = 0; i < count; ++i) {
			snprintf(filepath, BUFFER_SIZE, "%s/%s.bin", out_dir, filenames[i]);
			bin[i] = load<T>(filepath);
		}
		formatted_row<3>({ column_width, -18, 29 }, { step_name, "None", "loaded saved state" });
	}
	return bin;
}

/**
 * @brief Initialize the features based on the input shape.
 *
 * @param width Width of the image
 * @param height Height of the image
 * @return The initialized features
 */
np::Array<uint8_t> build_features(const uint16_t&, const uint16_t&) noexcept;
//np::Array<int32_t> select_percentile(const np::Array<uint8_t>&, const np::Array<uint8_t>&) noexcept;

/**
 * @brief Classify the given features with the trained classifiers.
 *
 * @param alphas Trained alphas
 * @param classifiers Trained classifiers
 * @param X_feat integrated features
 * @return Classification results
 */
np::Array<uint8_t> classify_viola_jones(const np::Array<float64_t>&, const np::Array<float64_t>&, const np::Array<int32_t>&) noexcept;

/**
 * @brief Initialize the weights of the weak classifiers based on the training labels.
 *
 * @param y_train Training labels
 * @return The initialized weights
 */
np::Array<float64_t> init_weights(const np::Array<uint8_t>&) noexcept;

/**
 * @brief Select the best classifier given their predictions.
 *
 * @param classifiers The weak classifiers
 * @param weights Trained weights of each classifier
 * @param X_feat Integrated features
 * @param y Features labels
 * @return Index of the best classifier, the best error and the accuracy array of the best classifier
 */
std::tuple<int32_t, float64_t, np::Array<float64_t>> select_best(const np::Array<float64_t>&, const np::Array<float64_t>&, const np::Array<int32_t>&,
								const np::Array<uint8_t>&) noexcept;

/**
 * @brief Train the weak classifiers.
 *
 * @param T Number of weak classifiers
 * @param X_feat Integrated features
 * @param X_feat_argsort Sorted indexes of the integrated features
 * @param y Features labels
 * @return List of trained alphas and the list of the final classifiers
 */
std::array<np::Array<float64_t>, 2> train_viola_jones(const size_t&, const np::Array<int32_t>&, const np::Array<uint16_t>&, const np::Array<uint8_t>&) noexcept;

/**
 * @brief Compute the accuracy score i.e. how close a given set of measurements is to their true values.
 *
 * @param y Ground truth labels
 * @param y_pred Predicted labels
 * @return computed accuracy score
 */
float64_t accuracy_score(const np::Array<uint8_t>&, const np::Array<uint8_t>&) noexcept;

/**
 * @brief Compute the precision score i.e. the ratio (TP / (TP + FP)) where TP is the number of true positives and FP the number of false positives.
 *
 * @param y Ground truth labels
 * @param y_pred Predicted labels
 * @return computed precision score
 */
float64_t precision_score(const np::Array<uint8_t>&, const np::Array<uint8_t>&) noexcept;

/**
 * @brief Compute the recall score i.e. the ratio (TP / (TP + FN)) where TP is the number of true positives and FN the number of false negatives.
 *
 * @param y Ground truth labels
 * @param y_pred Predicted labels
 * @return computed recall score
 */
float64_t recall_score(const np::Array<uint8_t>&, const np::Array<uint8_t>&) noexcept;

/**
 * @brief Compute the F1 score aka balanced F-score or F-measure.
 *
 * F1 = (2 * TP) / (2 * TP + FP + FN)
 * where TP is the true positives,
 * FP is the false positives,
 * and FN is the false negatives
 *
 * @param y Ground truth labels
 * @param y_pred Predicted labels
 * @return computed F1 score
 */
float64_t f1_score(const np::Array<uint8_t>&, const np::Array<uint8_t>&) noexcept;

/**
 * @brief Compute the confusion matrix to evaluate a given classification.
 *
 * A confusion matrix of a binary classification consists of a 2x2 matrix containing
 * | True negatives  | False positives |
 * | False negatives | True positives  |
 *
 * @param y Ground truth labels
 * @param y_pred Predicted labels
 * @return computed confusion matrix
 */
std::tuple<uint16_t, uint16_t, uint16_t, uint16_t> confusion_matrix(const np::Array<uint8_t>&, const np::Array<uint8_t>&) noexcept;