from config import COMPILE_WITH_C
from typing import Iterable
from numba import int32, float64
import numpy as np

# Choose the backend for every function in this module. With COMPILE_WITH_C the
# functions are compiled by numba's njit and tqdm_iter is a compiled pass-through;
# otherwise both names come from the project's `decorators` module (presumably an
# uncompiled fallback, possibly with real progress bars — confirm in decorators.py).
if COMPILE_WITH_C:
	from numba import njit
	@njit
	def tqdm_iter(iter: Iterable, _: str):
		"""Return ``iter`` unchanged; the description argument is ignored (no progress bar when compiled)."""
		return iter
else:
	from decorators import njit, tqdm_iter
	import sys
	# NOTE(review): nothing visible in this section recurses (argsort_bounded uses an
	# explicit stack); confirm whether the raised recursion limit is still required.
	sys.setrecursionlimit(10000)

@njit('uint32[:, :, :](uint8[:, :, :])')
def set_integral_image(X: np.ndarray) -> np.ndarray:
	"""Transform the input images in integrated images (CPU version).

	For every image ``Xi`` the result satisfies ``ii[y, x] == Xi[:y, :x].sum()``.
	Row 0 and column 0 are therefore zero. Because the output keeps the input's
	shape, the last row and last column of ``Xi`` never contribute (presumably the
	caller pads the images beforehand — confirm).

	Args:
		X (np.ndarray): Dataset of images, shape (n_images, height, width)

	Returns:
		np.ndarray: Dataset of integrated images (uint32), same shape as ``X``
	"""
	X_ii = np.empty_like(X, dtype = np.uint32)
	for i, Xi in enumerate(tqdm_iter(X, "Applying integral image")):
		ii = np.zeros_like(Xi, dtype = np.uint32)
		for y in range(1, Xi.shape[0]):
			# Running sum of source row y - 1, columns 0..x.
			s = 0
			for x in range(Xi.shape[1] - 1):
				# int() keeps the accumulator a Python int in the uncompiled fallback.
				# Under NumPy >= 2 promotion rules (NEP 50), 0 + uint8 stays uint8 and
				# the row sum would wrap at 256. Under numba this cast is a no-op.
				s += int(Xi[y - 1, x])
				ii[y, x + 1] = s + ii[y - 1, x + 1]
		X_ii[i] = ii
	return X_ii

@njit('int32[:, :](int32[:, :], uint16[:, :], uint8[:], float64[:])')
def train_weak_clf(X_feat: np.ndarray, X_feat_argsort: np.ndarray, y: np.ndarray, weights: np.ndarray) -> np.ndarray:
	"""Fit one decision stump per feature by scanning samples in sorted order (CPU version).

	Each feature row's samples are visited in ascending feature-value order, as given
	by ``X_feat_argsort``. Before a sample is counted, the weighted error of splitting
	there is evaluated for both labelings of the already-seen side. The first position
	reaching the smallest error defines the stump.

	Args:
		X_feat (np.ndarray): Feature values, one row per feature and one column per sample
		X_feat_argsort (np.ndarray): Per-row sample indices that sort ``X_feat`` ascending
		y (np.ndarray): Sample labels; 1 marks positives, 0 negatives
		weights (np.ndarray): Sample weights

	Returns:
		np.ndarray: Array of shape (n_features, 2) holding ``(threshold, polarity)`` per
			feature. The threshold is the feature value of the sample where the minimum
			was found. Polarity is 1 when more positive than negative samples precede it,
			otherwise -1.
	"""
	total_pos = weights[y == 1].sum()
	total_neg = weights[y == 0].sum()

	classifiers = np.empty((X_feat.shape[0], 2), dtype = np.int32)
	for i, feature in enumerate(tqdm_iter(X_feat, "Training weak classifiers")):
		n_pos = 0
		n_neg = 0
		w_pos = 0.0
		w_neg = 0.0
		best_error = float64(np.inf)
		threshold = 0
		polarity = 0
		for idx in X_feat_argsort[i]:
			# First term: error if the seen samples are called positive.
			# Second term: error if they are called negative.
			error = min(w_neg + total_pos - w_pos, w_pos + total_neg - w_neg)
			if error < best_error:
				best_error = error
				threshold = feature[idx]
				polarity = -1 if n_pos <= n_neg else 1

			if y[idx] != 1:
				n_neg += 1
				w_neg += weights[idx]
			else:
				n_pos += 1
				w_pos += weights[idx]

		classifiers[i, 0] = threshold
		classifiers[i, 1] = polarity
	return classifiers

@njit('uint32(uint32[:, :], int16, int16, int16, int16)')
def __compute_feature__(ii: np.ndarray, x: int, y: int, w: int, h: int) -> int:
	"""Sum one rectangle of the source image from four integral-image lookups (CPU version).

	Assuming ``ii`` comes from ``set_integral_image`` (so ``ii[r, c]`` is the sum of the
	source pixels ``[:r, :c]``), the result is the sum over the ``h`` x ``w`` rectangle
	whose top-left corner is at column ``x``, row ``y``.

	Args:
		ii (np.ndarray): Integrated image
		x (int): Left column of the rectangle
		y (int): Top row of the rectangle
		w (int): Width of the rectangle
		h (int): Height of the rectangle

	Returns:
		int: Sum of the pixels inside the rectangle
	"""
	bottom = y + h
	right = x + w
	# Terms are combined in this order so every partial result stays non-negative
	# for a monotone integral image (no unsigned wrap-around on the way).
	return ii[bottom, right] + ii[y, x] - ii[bottom, x] - ii[y, right]

@njit('int32[:, :](uint8[:, :, :, :], uint32[:, :, :])')
def apply_features(feats: np.ndarray, X_ii: np.ndarray) -> np.ndarray:
	"""Apply the features on an integrated image dataset (CPU version).

	Each feature consists of two "positive" and two "negative" rectangles, each given
	as ``(x, y, w, h)``. Its value on an image is the pixel sum of the positive
	rectangles minus the pixel sum of the negative ones.

	Args:
		feats (np.ndarray): Features to apply. ``feats[i]`` unpacks into
			``(positive, negative)`` and each provides two ``(x, y, w, h)`` rows
			(presumably shape (n_features, 2, 2, 4) — confirm against the feature builder)
		X_ii (np.ndarray): Integrated image dataset, shape (n_images, height, width)

	Returns:
		np.ndarray: int32 array of shape (n_features, n_images) with the feature values
	"""
	X_feat = np.empty((feats.shape[0], X_ii.shape[0]), dtype = np.int32)

	for i, (p, n) in enumerate(tqdm_iter(feats, "Applying features")):
		# The rectangle coordinates depend only on the feature: unpack them once
		# per feature instead of once per (feature, image) pair.
		p_x, p_y, p_w, p_h = p[0]
		p1_x, p1_y, p1_w, p1_h = p[1]
		n_x, n_y, n_w, n_h = n[0]
		n1_x, n1_y, n1_w, n1_h = n[1]
		for j, x_i in enumerate(X_ii):
			p1 = __compute_feature__(x_i, p_x, p_y, p_w, p_h) + __compute_feature__(x_i, p1_x, p1_y, p1_w, p1_h)
			n1 = __compute_feature__(x_i, n_x, n_y, n_w, n_h) + __compute_feature__(x_i, n1_x, n1_y, n1_w, n1_h)
			# The sums are unsigned; cast before subtracting so a negative value is representable.
			X_feat[i, j] = int32(p1) - int32(n1)

	return X_feat

@njit('int32(int32[:], uint16[:], int32, int32)')
def _as_partition_(d_a: np.ndarray, d_indices: np.ndarray, low: int, high: int) -> int:
	"""Lomuto partition step of the indirect quicksort.

	The key ``d_a[d_indices[high]]`` is the pivot. Positions ``low..high`` of
	``d_indices`` are rearranged so that every index whose key is strictly smaller than
	the pivot comes first, immediately followed by the pivot index. ``d_a`` is only read.

	Args:
		d_a (np.ndarray): Keys compared during partitioning
		d_indices (np.ndarray): Indices into ``d_a``, permuted in place
		low (int): First position to partition (inclusive)
		high (int): Last position to partition (inclusive); holds the pivot index

	Returns:
		int: Final position of the pivot index in ``d_indices``
	"""
	# ``store`` is the next slot for an index whose key is below the pivot.
	store = low
	# ``cursor`` outlives the loop: when the scan runs it ends on ``high`` (the pivot).
	cursor = low
	for cursor in range(low, high + 1):
		if d_a[d_indices[cursor]] < d_a[d_indices[high]]:
			d_indices[store], d_indices[cursor] = d_indices[cursor], d_indices[store]
			store += 1

	# Drop the pivot into its sorted slot, right after the smaller keys.
	d_indices[store], d_indices[cursor] = d_indices[cursor], d_indices[store]
	return store

@njit('void(int32[:], uint16[:], int32, int32)')
def argsort_bounded(d_a: np.ndarray, d_indices: np.ndarray, low: int, high: int) -> None:
	"""Perform an indirect sort of a given array within a given bound.

	Positions ``low..high`` (inclusive) of ``d_indices`` are reordered in place so that
	the keys ``d_a[d_indices[k]]`` are ascending over that range. The sort is an
	iterative (explicit-stack) quicksort; ``d_a`` itself is not modified and the order
	of equal keys is not preserved.

	Args:
		d_a (np.ndarray): Array of keys to sort by
		d_indices (np.ndarray): Array of indices into ``d_a``, permuted in place
		low (int): First position of ``d_indices`` to sort (inclusive)
		high (int): Last position of ``d_indices`` to sort (inclusive)
	"""
	# Fewer than two elements: already sorted. Returning here also prevents writing
	# the two initial stack entries into a stack with fewer than two slots (an
	# out-of-bounds write, which numba does not bounds-check).
	if high <= low:
		return

	# Stack of (low, high) pairs replacing recursion. Every pushed range holds at least
	# two elements and pushed ranges are disjoint, so high - low + 1 slots suffice.
	stack = np.empty((high - low + 1,), dtype = np.int32)
	stack[0] = low
	stack[1] = high
	top = 1

	while top >= 0:
		hi = stack[top]
		lo = stack[top - 1]
		top -= 2

		p = _as_partition_(d_a, d_indices, lo, hi)

		# Only sub-ranges with at least two elements still need sorting.
		if p - 1 > lo:
			stack[top + 1] = lo
			stack[top + 2] = p - 1
			top += 2

		if p + 1 < hi:
			stack[top + 1] = p + 1
			stack[top + 2] = hi
			top += 2

@njit('uint16[:, :](int32[:, :])')
def argsort(X_feat: np.ndarray) -> np.ndarray:
	"""Perform a row-wise indirect sort of a given 2-D array.

	Args:
		X_feat (np.ndarray): Array to sort; each row is sorted independently

	Raises:
		ValueError: If rows are longer than 65536 elements, whose positions would not
			fit in the uint16 index type

	Returns:
		np.ndarray: uint16 array of the same shape whose row ``i`` lists the column
			indices that sort ``X_feat[i]`` in ascending order (not a stable sort)
	"""
	# Indices are uint16 (fixed by this signature and by train_weak_clf's argument
	# type); longer rows would silently wrap around and yield a wrong ordering.
	if X_feat.shape[1] > 65536:
		raise ValueError("argsort: rows longer than 65536 elements do not fit uint16 indices")
	indices = np.empty_like(X_feat, dtype = np.uint16)
	indices[:, :] = np.arange(indices.shape[1])
	for i in tqdm_iter(range(X_feat.shape[0]), "argsort"):
		argsort_bounded(X_feat[i], indices[i], 0, X_feat[i].shape[0] - 1)
	return indices