from numba import float64, uint32, cuda, int32
from config import COMPILE_WITH_C, NB_THREADS, NB_THREADS_2D, NB_THREADS_3D, M
import numpy as np

if COMPILE_WITH_C:
	from numba import njit
else:
	from decorators import njit

@njit('uint32[:, :, :](uint32[:, :, :])')
def __scanCPU_3d__(X: np.ndarray) -> np.ndarray:
	"""Exclusive prefix sum (scan) along the last axis, computed in place on the CPU.

	After the call X[a, b, c] holds the sum of the original values X[a, b, :c],
	so X[a, b, 0] becomes 0.

	Args:
		X (np.ndarray): 3-D uint32 dataset to scan; it is overwritten.

	Returns:
		np.ndarray: The same array object X, now holding the scanned values.
	"""
	depth, rows, cols = X.shape
	for a in range(depth):
		for b in range(rows):
			running = 0
			for c in range(cols):
				# Store the total of everything before this element, then fold it in.
				value = X[a, b, c]
				X[a, b, c] = running
				running += value
	return X

@cuda.jit('void(uint16, uint16, uint32[:, :, :], uint32[:, :, :])')
def __kernel_scan_3d__(n: int, j: int, d_inter: np.ndarray, d_a: np.ndarray) -> None:
	"""Per-block exclusive prefix sum (Blelloch scan) along the last axis, in place (GPU kernel).

	Each block loads one NB_THREADS_2D tile of image cuda.blockIdx.z into shared memory
	(the x grid axis runs along the columns of d_a, the y grid axis along its rows).
	It scans every tile row with an up-sweep / down-sweep, then writes the scanned tile
	back to d_a. Each tile row's total is stored in d_inter so the host can combine the
	blocks afterwards.

	NOTE(review): the tree indexing assumes cuda.blockDim.x == 2**M (M from config);
	confirm that NB_THREADS_2D[0] == 2**M.

	Args:
		n (int): Number of columns of each image (length of the scanned axis).
		j (int): Number of rows of each image.
		d_inter (np.ndarray): Output block totals, indexed [image, row, x-block].
		d_a (np.ndarray): Images indexed [image, row, column]; scanned in place.
	"""
	x_coor, y_coor = cuda.grid(2)
	tx = cuda.threadIdx.x
	ty = cuda.threadIdx.y
	z = cuda.blockIdx.z
	last = cuda.blockDim.x - 1
	inside = x_coor < n and y_coor < j

	sA = cuda.shared.array(NB_THREADS_2D, uint32)
	# Out-of-range threads contribute 0 so they do not change the sums.
	sA[tx, ty] = d_a[z, y_coor, x_coor] if inside else 0
	cuda.syncthreads()

	# Up-sweep: at level d, thread tx adds the left subtree into node (tx + 1) * 2**(d + 1) - 1.
	# Every thread runs every level (no early break) so that all threads reach the barrier.
	# The barrier is required because level d reads nodes written by other threads at level d - 1.
	for d in range(M):
		i2 = (tx + 1) * 2**(d + 1) - 1
		i1 = i2 - 2**d
		if i2 <= last:
			sA[i2, ty] += sA[i1, ty]
		cuda.syncthreads()

	# The last node now holds the tile-row total. Publish it only for real image rows,
	# because d_inter has exactly j rows. Then clear it to seed the exclusive down-sweep.
	if tx == 0:
		if y_coor < j:
			d_inter[z, y_coor, cuda.blockIdx.x] = sA[last, ty]
		sA[last, ty] = 0
	cuda.syncthreads()

	# Down-sweep: visit the same nodes from the root back to the leaves, with a barrier per level.
	for d in range(M - 1, -1, -1):
		i2 = (tx + 1) * 2**(d + 1) - 1
		i1 = i2 - 2**d
		if i2 <= last:
			t = sA[i1, ty]
			sA[i1, ty] = sA[i2, ty]
			sA[i2, ty] += t
		cuda.syncthreads()

	if inside:
		d_a[z, y_coor, x_coor] = sA[tx, ty]

@cuda.jit('void(uint32[:, :, :], uint32[:, :, :], uint16, uint16)')
def __add_3d__(d_X: np.ndarray, d_s: np.ndarray, n: int, m: int) -> None:
	"""Add each x-block's offset to every element of that block (GPU kernel).

	Thread (column, row) of block (bx, by, image) performs
	d_X[image, row, column] += d_s[image, row, bx]. Threads outside the n x m image do nothing.

	Args:
		d_X (np.ndarray): Block-scanned images indexed [image, row, column]; updated in place.
		d_s (np.ndarray): Offsets to add, indexed [image, row, x-block].
		n (int): Number of columns of each image (__scanGPU_3d__ passes the image width).
		m (int): Number of rows of each image (__scanGPU_3d__ passes the image height).
	"""
	col, row = cuda.grid(2)
	if col >= n or row >= m:
		return
	img = cuda.blockIdx.z
	d_X[img, row, col] += d_s[img, row, cuda.blockIdx.x]

def __scanGPU_3d__(X: np.ndarray) -> np.ndarray:
	"""Parallel exclusive prefix sum (scan) along the last axis of a dataset.

	Read more: https://developer.nvidia.com/gpugems/gpugems3/part-vi-gpu-computing/chapter-39-parallel-prefix-sum-scan-cuda

	Each block of __kernel_scan_3d__ scans NB_THREADS_2D[0] consecutive columns. The
	per-block totals are then scanned themselves: recursively on the GPU when there are
	many blocks, otherwise on the CPU. Finally they are added back as per-block offsets.

	Args:
		X (np.ndarray): Dataset of images indexed [image, row, column]; uint32 to match the kernels.

	Returns:
		np.ndarray: Scanned dataset of images (a new array; X itself is not modified).
	"""
	k, height, n = X.shape
	if X.size == 0:
		# A zero-sized grid cannot be launched; an empty scan is just an empty copy.
		return X.copy()

	block_w, block_h = NB_THREADS_2D
	# The kernels index columns with the x grid axis and rows with the y grid axis.
	# So the x blocks must tile the width and the y blocks the height.
	n_block_x = -(-n // block_w)
	n_block_y = -(-height // block_h)
	grid = (n_block_x, n_block_y, k)
	d_X = cuda.to_device(X)

	# Every entry is written by the kernel, so no host buffer needs to be uploaded.
	d_inter = cuda.device_array((k, height, n_block_x), dtype = np.uint32)
	__kernel_scan_3d__[grid, NB_THREADS_2D](n, height, d_inter, d_X)
	cuda.synchronize()

	inter = d_inter.copy_to_host()
	if n_block_x >= block_w:
		sums = __scanGPU_3d__(inter)
		d_s = cuda.to_device(sums)
		__add_3d__[grid, NB_THREADS_2D](d_X, d_s, n, height)
		cuda.synchronize()
		return d_X.copy_to_host()

	sums = __scanCPU_3d__(inter)
	X_scan = d_X.copy_to_host()
	# Column c lies in x-block c // block_w (blocks are NB_THREADS_2D[0] = blockDim.x wide).
	# sums[..., 0] is 0 after the exclusive scan, so adding it to block 0 is harmless.
	X_scan += sums[:, :, np.arange(n) // block_w]
	return X_scan

@cuda.jit('void(uint32[:, :, :], uint32[:, :, :])')
def __transpose_kernel__(d_X: np.ndarray, d_Xt: np.ndarray) -> None:
	"""GPU kernel of the function __transpose_3d__: d_Xt[i] = d_X[i].T for image i = cuda.blockIdx.z.

	Each block stages one tile of d_X in shared memory, waits for the whole tile, and
	then writes the tile back with the block and thread axes swapped.

	Note: the tile has shape NB_THREADS_2D. It is written as tile[threadIdx.y, threadIdx.x]
	and read as tile[threadIdx.x, threadIdx.y], which only stays in bounds when
	NB_THREADS_2D is square.

	Args:
		d_X (np.ndarray): Dataset of images indexed [image, row, column].
		d_Xt (np.ndarray): Output dataset of shape (images, columns, rows).
	"""
	tile = cuda.shared.array(NB_THREADS_2D, dtype = uint32)
	tx, ty = cuda.threadIdx.x, cuda.threadIdx.y
	img = cuda.blockIdx.z
	n_rows = d_X.shape[1]
	n_cols = d_X.shape[2]

	# Stage: the x grid axis walks the rows and the y grid axis the columns of the source.
	src_row, src_col = cuda.grid(2)
	if src_row < n_rows and src_col < n_cols:
		tile[ty, tx] = d_X[img, src_row, src_col]
	cuda.syncthreads()

	# Write: the roles of blockIdx.x/blockIdx.y and threadIdx.x/threadIdx.y are swapped.
	dst_row = cuda.blockIdx.y * cuda.blockDim.y + tx
	dst_col = cuda.blockIdx.x * cuda.blockDim.x + ty
	if dst_row < n_cols and dst_col < n_rows:
		d_Xt[img, dst_row, dst_col] = tile[tx, ty]

def __transpose_3d__(X: np.ndarray) -> np.ndarray:
	"""Transpose every image of a dataset on the GPU.

	Args:
		X (np.ndarray): Dataset of images indexed [image, row, column] (uint32, as __transpose_kernel__ expects).

	Returns:
		np.ndarray: New host array of shape (images, columns, rows) whose image i equals X[i].T.
	"""
	n_images, n_rows, n_cols = X.shape
	# x blocks tile the rows and y blocks tile the columns of each source image; z selects the image.
	grid = (-(-n_rows // NB_THREADS_2D[0]), -(-n_cols // NB_THREADS_2D[1]), n_images)
	d_src = cuda.to_device(X)
	d_dst = cuda.to_device(np.empty((n_images, n_cols, n_rows), dtype = X.dtype))
	__transpose_kernel__[grid, NB_THREADS_2D](d_src, d_dst)
	return d_dst.copy_to_host()

def set_integral_image(X: np.ndarray) -> np.ndarray:
	"""Transform the input images into integral images (GPU version).

	The dataset is cast to uint32, then it goes through two passes of "exclusive scan
	along the last axis, then transpose". The first pass sums along the columns; the
	second sums along the rows and restores the original orientation.

	Args:
		X (np.ndarray): Dataset of images indexed [image, row, column].

	Returns:
		np.ndarray: uint32 dataset of integral images with the same shape as X.
	"""
	ii = X.astype(np.uint32)
	for _ in range(2):
		ii = __transpose_3d__(__scanGPU_3d__(ii))
	return ii

@cuda.jit('void(int32[:, :], uint8[:], int32[:, :], uint16[:, :], float64[:], float64, float64)')
def __train_weak_clf_kernel__(d_classifiers: np.ndarray, d_y: np.ndarray, d_X_feat: np.ndarray,
		d_X_feat_argsort: np.ndarray, d_weights: np.ndarray, total_pos: float, total_neg: float) -> None:
	"""GPU kernel of the function train_weak_clf: one thread trains the classifier of one feature.

	The thread walks the samples of its feature in the order given by d_X_feat_argsort.
	Before each sample it evaluates the weighted error of splitting just before that
	sample, under both labellings of the samples already seen, and keeps the lowest.
	The result (threshold, polarity) is written to d_classifiers[i].

	Args:
		d_classifiers (np.ndarray): Output, one (threshold, polarity) row per feature.
		d_y (np.ndarray): Sample labels; 1 is positive, anything else negative.
		d_X_feat (np.ndarray): Feature values indexed [feature, sample].
		d_X_feat_argsort (np.ndarray): Per feature, sample indices in sorted feature order.
		d_weights (np.ndarray): Sample weights.
		total_pos (float): Total weight of the positive samples.
		total_neg (float): Total weight of the negative samples.
	"""
	# Linear thread id: flatten (threadIdx.x, threadIdx.y, threadIdx.z) inside the block, then offset by the block.
	threads_per_block = cuda.blockDim.x * cuda.blockDim.y * cuda.blockDim.z
	local_id = (cuda.threadIdx.x * cuda.blockDim.y + cuda.threadIdx.y) * cuda.blockDim.z + cuda.threadIdx.z
	i = cuda.blockIdx.x * threads_per_block + local_id
	if i >= d_classifiers.shape[0]:
		return

	n_pos_seen, n_neg_seen = 0, 0
	w_pos_seen, w_neg_seen = 0.0, 0.0
	min_error = float64(np.inf)
	best_threshold = 0
	best_polarity = 0

	for sample in d_X_feat_argsort[i]:
		# Weighted error if the seen samples are all called positive / all called negative.
		err_seen_pos = w_neg_seen + total_pos - w_pos_seen
		err_seen_neg = w_pos_seen + total_neg - w_neg_seen
		error = min(err_seen_pos, err_seen_neg)
		if error < min_error:
			min_error = error
			best_threshold = d_X_feat[i, sample]
			best_polarity = 1 if n_pos_seen > n_neg_seen else -1

		if d_y[sample] != 1:
			n_neg_seen += 1
			w_neg_seen += d_weights[sample]
		else:
			n_pos_seen += 1
			w_pos_seen += d_weights[sample]

	d_classifiers[i] = (best_threshold, best_polarity)

#@njit('int32[:, :](int32[:, :], uint16[:, :], uint8[:], float64[:])')
def train_weak_clf(X_feat: np.ndarray, X_feat_argsort: np.ndarray, y: np.ndarray, weights: np.ndarray) -> np.ndarray:
	"""Train the weak classifiers on a given dataset (GPU version).

	One GPU thread of __train_weak_clf_kernel__ trains the classifier of one feature (row of X_feat).

	Args:
		X_feat (np.ndarray): Feature values indexed [feature, sample] (int32).
		X_feat_argsort (np.ndarray): Per feature, sample indices in sorted feature order (uint16).
		y (np.ndarray): Sample labels (uint8); 1 marks a positive sample, anything else a negative one.
		weights (np.ndarray): Sample weights (float64).

	Returns:
		np.ndarray: int32 array of shape (n_features, 2) holding (threshold, polarity) per feature.
	"""
	n_features = X_feat.shape[0]
	if n_features == 0:
		# A zero-block launch is invalid; there is nothing to train.
		return np.empty((0, 2), dtype = np.int32)

	# The kernel counts every label other than 1 as negative, so total_neg sums exactly those samples.
	total_pos = weights[y == 1].sum()
	total_neg = weights[y != 1].sum()

	d_classifiers = cuda.device_array((n_features, 2), dtype = np.int32)
	d_X_feat = cuda.to_device(X_feat)
	d_X_feat_argsort = cuda.to_device(X_feat_argsort)
	d_weights = cuda.to_device(weights)
	d_y = cuda.to_device(y)

	# Plain Python int: the former .astype(np.uint16) wrapped silently past 65535 blocks,
	# leaving classifiers untrained.
	threads_per_block = int(np.prod(NB_THREADS_3D))
	n_blocks = -(-n_features // threads_per_block)
	__train_weak_clf_kernel__[n_blocks, NB_THREADS_3D](d_classifiers, d_y, d_X_feat, d_X_feat_argsort, d_weights, total_pos, total_neg)
	return d_classifiers.copy_to_host()

@cuda.jit('uint32(uint32[:, :], int16, int16, int16, int16)', device = True)
def __compute_feature__(ii: np.ndarray, x: int, y: int, w: int, h: int) -> int:
	"""Rectangle sum of size w x h at column x, row y, read from an integral image (GPU device function).

	Combines the four corner entries as
	ii[y + h, x + w] + ii[y, x] - ii[y + h, x] - ii[y, x + w]; the result is returned as uint32.

	Args:
		ii (np.ndarray): Integral image indexed [row, column].
		x (int): Column of the rectangle's corner.
		y (int): Row of the rectangle's corner.
		w (int): Width of the rectangle.
		h (int): Height of the rectangle.

	Returns:
		int: Computed feature (rectangle sum).
	"""
	far_corner = ii[y + h, x + w]
	near_corner = ii[y, x]
	lower_edge = ii[y + h, x]
	right_edge = ii[y, x + w]
	return far_corner + near_corner - lower_edge - right_edge

@cuda.jit('void(int32[:, :], uint8[:, :, :, :], uint32[:, :, :])')
def __apply_feature_kernel__(X_feat: np.ndarray, feats: np.ndarray, X_ii: np.ndarray) -> None:
	"""GPU kernel of the function apply_features: one thread evaluates one feature on one image.

	Thread (f, s) from cuda.grid(2) evaluates feature f on image s and writes the result to
	X_feat[f, s]. The value is the sum of the two rectangles feats[f, 0, :] minus the sum of
	the two rectangles feats[f, 1, :], each evaluated on X_ii[s].

	Args:
		X_feat (np.ndarray): Output feature values indexed [feature, image].
		feats (np.ndarray): Features indexed [feature, sign (0 added, 1 subtracted), rectangle, (x, y, w, h)].
		X_ii (np.ndarray): Integral image dataset indexed [image, row, column].
	"""
	f, s = cuda.grid(2)
	if f < feats.shape[0] and s < X_ii.shape[0]:
		ii = X_ii[s]
		a_x, a_y, a_w, a_h = feats[f, 0, 0]
		b_x, b_y, b_w, b_h = feats[f, 0, 1]
		c_x, c_y, c_w, c_h = feats[f, 1, 0]
		d_x, d_y, d_w, d_h = feats[f, 1, 1]
		positive = __compute_feature__(ii, a_x, a_y, a_w, a_h) + __compute_feature__(ii, b_x, b_y, b_w, b_h)
		negative = __compute_feature__(ii, c_x, c_y, c_w, c_h) + __compute_feature__(ii, d_x, d_y, d_w, d_h)
		X_feat[f, s] = positive - negative

#@njit('int32[:, :](uint8[:, :, :, :], uint32[:, :, :])')
def apply_features(feats: np.ndarray, X_ii: np.ndarray) -> np.ndarray:
	"""Apply the features on an integral image dataset (GPU version).

	Args:
		feats (np.ndarray): Features to apply (uint8, indexed [feature, sign, rectangle, (x, y, w, h)]).
		X_ii (np.ndarray): Integral image dataset (uint32, indexed [image, row, column]).

	Returns:
		np.ndarray: int32 array of shape (n_features, n_images) with the applied features.
	"""
	n_feats, n_images = feats.shape[0], X_ii.shape[0]
	if n_feats == 0 or n_images == 0:
		# A zero-block launch is invalid; the result is simply empty.
		return np.empty((n_feats, n_images), dtype = np.int32)

	# Every entry is written by the kernel, so allocate on the device instead of uploading garbage.
	d_X_feat = cuda.device_array((n_feats, n_images), dtype = np.int32)
	d_feats = cuda.to_device(feats)
	d_X_ii = cuda.to_device(X_ii)
	# Plain Python ints: the former .astype(np.uint16) wrapped silently past 65535 blocks.
	n_x_blocks = -(-n_feats // NB_THREADS_2D[0])
	n_y_blocks = -(-n_images // NB_THREADS_2D[1])
	__apply_feature_kernel__[(n_x_blocks, n_y_blocks), NB_THREADS_2D](d_X_feat, d_feats, d_X_ii)
	cuda.synchronize()
	return d_X_feat.copy_to_host()

@cuda.jit('int32(int32[:], uint16[:], int32, int32)', device = True)
def as_partition(a: np.ndarray, indices: np.ndarray, l: int, h: int) -> int:
	"""Lomuto partition of indices[l..h] (inclusive) by the key a[indices[h]] (GPU device function).

	Afterwards, every position before the returned one holds an index whose key is
	strictly smaller than the pivot key. The returned position holds the pivot index.
	If h < l, nothing is changed and l is returned.

	Args:
		a (np.ndarray): Keys; only read.
		indices (np.ndarray): Positions into a, permuted in place.
		l (int): First position of the range.
		h (int): Last position of the range; indices[h] is the pivot.

	Returns:
		int: Final position of the pivot.
	"""
	if h < l:
		return l
	pivot = a[indices[h]]
	store = l
	for scan in range(l, h):
		if a[indices[scan]] < pivot:
			indices[store], indices[scan] = indices[scan], indices[store]
			store += 1
	indices[store], indices[h] = indices[h], indices[store]
	return store

@cuda.jit('void(int32[:], uint16[:], int32, int32)', device = True)
def argsort_bounded(a: np.ndarray, indices: np.ndarray, l: int, h: int) -> None:
	"""Reorder indices[l..h] (inclusive) so that a[indices[l..h]] is ascending (GPU device function).

	Iterative quicksort built on as_partition, with an explicit stack of (low, high) pairs
	kept in a per-thread local array. The larger sub-range is pushed first, so the smaller
	one is popped and processed next. This bounds the pending ranges to about
	log2(h - l + 1) + 2, so 128 slots (64 pairs) are ample for any int32 range.

	The previous version always pushed both halves in the same order. Degenerate splits
	then needed O(n) slots of its fixed 6977-slot stack, with no overflow check.

	Args:
		a (np.ndarray): Keys to sort by; only read.
		indices (np.ndarray): Positions into a, permuted in place.
		l (int): First position of the range to sort.
		h (int): Last position of the range to sort (inclusive).
	"""
	stack = cuda.local.array(128, int32)
	stack[0] = l
	stack[1] = h
	top = 1

	while top >= 0:
		high = stack[top]
		low = stack[top - 1]
		top -= 2

		# Ranges with fewer than two elements are already sorted.
		if low >= high:
			continue

		p = as_partition(a, indices, low, high)
		left_size = p - low    # elements in [low, p - 1]
		right_size = high - p  # elements in [p + 1, high]

		if left_size >= right_size:
			# Larger (left) side first, smaller (right) side on top.
			if left_size > 1:
				top += 1
				stack[top] = low
				top += 1
				stack[top] = p - 1
			if right_size > 1:
				top += 1
				stack[top] = p + 1
				top += 1
				stack[top] = high
		else:
			# Larger (right) side first, smaller (left) side on top.
			if right_size > 1:
				top += 1
				stack[top] = p + 1
				top += 1
				stack[top] = high
			if left_size > 1:
				top += 1
				stack[top] = low
				top += 1
				stack[top] = p - 1

@cuda.jit('void(int32[:, :], uint16[:, :])')
def argsort_flatter(X_feat: np.ndarray, indices: np.ndarray) -> None:
	"""Argsort each row of X_feat into the matching row of indices (GPU kernel, one thread per row).

	Args:
		X_feat (np.ndarray): Keys indexed [row, column]; only read.
		indices (np.ndarray): Output positions indexed [row, rank]; fully overwritten for handled rows.
	"""
	row = cuda.grid(1)
	if row >= X_feat.shape[0]:
		return
	order = indices[row]
	# Start from the identity permutation, then sort it by the row's keys.
	for col in range(indices.shape[1]):
		order[col] = col
	argsort_bounded(X_feat[row], order, 0, X_feat.shape[1] - 1)

def argsort(X_feat: np.ndarray) -> np.ndarray:
	"""Argsort every row of X_feat on the GPU.

	Args:
		X_feat (np.ndarray): 2-D int32 array (as argsort_flatter expects); each row is sorted independently.

	Returns:
		np.ndarray: uint16 array with X_feat's shape. Row i holds the column indices that
		order X_feat[i] ascending. Ties may come out in any order, because quicksort is not stable.

	Raises:
		ValueError: If X_feat has more columns than uint16 indices can address (65536).
	"""
	n_rows, n_cols = X_feat.shape
	max_cols = int(np.iinfo(np.uint16).max) + 1
	if n_cols > max_cols:
		raise ValueError(f"argsort supports at most {max_cols} columns, got {n_cols}")
	if n_rows == 0 or n_cols == 0:
		# A zero-block launch is invalid; the result is simply empty.
		return np.empty((n_rows, n_cols), dtype = np.uint16)

	d_X_feat = cuda.to_device(X_feat)
	# The kernel initialises every entry, so no host buffer needs to be uploaded.
	d_indices = cuda.device_array((n_rows, n_cols), dtype = np.uint16)
	n_blocks = -(-n_rows // NB_THREADS)
	argsort_flatter[n_blocks, NB_THREADS](d_X_feat, d_indices)
	cuda.synchronize()
	return d_indices.copy_to_host()