#include "data.hpp"
#include "config.hpp"

#if GPU_BOOSTED

/**
 * @brief Exclusive prefix sum (scan) along the last axis of a 3D dataset (CPU version).
 *
 * Every contiguous run of X.shape[2] elements (one row of one image) is scanned
 * independently: output element z of a row is the sum of input elements [0, z)
 * of that row, so every row starts with 0. Sums wrap modulo 2^32.
 *
 * @param X Dataset of images to apply sum
 * @return Scanned dataset of images, same shape as X
 */
static np::Array<uint32_t> __scanCPU_3d__(const np::Array<uint32_t>& X) noexcept {
	np::Array<uint32_t> X_scan = np::empty<uint32_t>(X.shape);
	const size_t row_len = X_scan.shape[2];
	const size_t n_elems = np::prod(X_scan.shape);

	for (size_t row_start = 0; row_start < n_elems; row_start += row_len) {
		uint32_t running = 0;
		for (size_t off = 0; off < row_len; ++off) {
			const size_t idx = row_start + off;
			// Store the sum of everything before idx, then include idx itself
			X_scan[idx] = running;
			running += X[idx];
		}
	}
	return X_scan;
}

/**
 * @brief GPU kernel doing a block-wise exclusive prefix sum (Blelloch work-efficient scan).
 *
 * Each block loads a tile of blockDim.x consecutive elements along the last axis of
 * every row it covers. It scans each tile row in shared memory, writes each tile row's
 * total into d_inter, and stores the exclusive scan back into d_X.
 *
 * Expected launch (see __scanGPU_3d__): grid (blocks along the last axis,
 * blocks along the rows, images), blocks of NB_THREADS_2D_X x NB_THREADS_2D_Y.
 * The up/down sweeps assume blockDim.x == 2^M (M from config.hpp) — confirm when
 * changing the block size.
 *
 * @param n Length of the scanned (last) axis of d_X
 * @param j Number of rows (second axis) of d_X
 * @param d_inter Per-tile totals on device, indexed (image, row, blockIdx.x)
 * @param d_X Dataset of images on device, scanned in place
 */
static __global__ void __kernel_scan_3d__(const uint16_t n, const uint16_t j, np::Array<uint32_t> d_inter, np::Array<uint32_t> d_X) {
	const size_t x_coor = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
	const size_t y_coor = static_cast<size_t>(blockIdx.y) * blockDim.y + threadIdx.y;
	const bool in_bounds = x_coor < n && y_coor < j;
	// The same flat index is used to load the input and to store the result
	const size_t idx = blockIdx.z * d_X.shape[1] * d_X.shape[2] + y_coor * d_X.shape[2] + x_coor;
	const size_t levels = static_cast<size_t>(M);
	const size_t col = threadIdx.y;

	// Tile position p (along x) of thread row `col` lives at sA[p * NB_THREADS_2D_Y + col]
	__shared__ uint32_t sA[NB_THREADS_2D_X * NB_THREADS_2D_Y];
	sA[threadIdx.x * NB_THREADS_2D_Y + col] = in_bounds ? d_X[idx] : 0;
	__syncthreads();

	// Up-sweep: build partial sums in place. The loop bound is block-uniform, so the
	// barrier is reached by every thread; each level reads sums written by the previous one.
	for (size_t d = 0; d < levels; ++d) {
		const size_t k = static_cast<size_t>(threadIdx.x) << (d + 1);
		const size_t i1 = k + (static_cast<size_t>(1) << d) - 1;
		const size_t i2 = k + (static_cast<size_t>(1) << (d + 1)) - 1;
		if (i2 < blockDim.x)
			sA[i2 * NB_THREADS_2D_Y + col] += sA[i1 * NB_THREADS_2D_Y + col];
		__syncthreads();
	}

	// The last tile slot now holds the tile total: publish it and seed the down-sweep with 0
	const size_t last = (blockDim.x - 1) * NB_THREADS_2D_Y + col;
	if (threadIdx.x == 0) {
		if (y_coor < j)
			d_inter[blockIdx.z * d_inter.shape[1] * d_inter.shape[2] + y_coor * d_inter.shape[2] + blockIdx.x] = sA[last];
		sA[last] = 0;
	}
	__syncthreads();

	// Down-sweep: turn the partial sums into an exclusive scan
	for (size_t d = levels; d-- > 0;) {
		const size_t k = static_cast<size_t>(threadIdx.x) << (d + 1);
		const size_t i1 = k + (static_cast<size_t>(1) << d) - 1;
		const size_t i2 = k + (static_cast<size_t>(1) << (d + 1)) - 1;
		if (i2 < blockDim.x) {
			const uint32_t t = sA[i1 * NB_THREADS_2D_Y + col];
			sA[i1 * NB_THREADS_2D_Y + col] = sA[i2 * NB_THREADS_2D_Y + col];
			sA[i2 * NB_THREADS_2D_Y + col] += t;
		}
		__syncthreads();
	}

	if (in_bounds)
		d_X[idx] = sA[threadIdx.x * NB_THREADS_2D_Y + col];
}

/**
 * @brief GPU kernel adding the scanned per-tile offsets to every element of d_X.
 *
 * Thread (x, y) of block (bx, by, bz) adds d_s[bz, y, bx] to d_X[bz, y, x], so every
 * element of tile bx receives the same offset. Expected to be launched with the same
 * grid/block configuration as __kernel_scan_3d__.
 *
 * @param d_X Dataset of images on device, updated in place
 * @param d_s Scanned tile totals on device, indexed (image, row, tile)
 * @param n Length of the last axis of d_X
 * @param m Number of rows (second axis) of d_X
 */
static __global__ void __add_3d__(np::Array<uint32_t> d_X, const np::Array<uint32_t> d_s, const uint16_t n, const uint16_t m) {
	const size_t x_coor = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
	const size_t y_coor = static_cast<size_t>(blockIdx.y) * blockDim.y + threadIdx.y;
	if(x_coor >= n || y_coor >= m)
		return;

	const size_t x_idx = blockIdx.z * d_X.shape[1] * d_X.shape[2] + y_coor * d_X.shape[2] + x_coor;
	// d_s has its own (image, row, tile) layout, so it must use its own strides
	const size_t s_idx = blockIdx.z * d_s.shape[1] * d_s.shape[2] + y_coor * d_s.shape[2] + blockIdx.x;
	d_X[x_idx] += d_s[s_idx];
}

/**
 * @brief Parallel exclusive prefix sum (scan) along the last axis of a 3D dataset.
 *
 * The work is done in three steps:
 * 1. Every tile of NB_THREADS_2D_X elements is scanned on the GPU, and its total is collected in `inter`.
 * 2. `inter` is scanned, recursively on the GPU when it is long enough, otherwise on the CPU.
 * 3. The scanned totals are added back to their tiles.
 *
 * Read more: https://developer.nvidia.com/gpugems/gpugems3/part-vi-gpu-computing/chapter-39-parallel-prefix-sum-scan-cuda
 *
 * @param X Dataset of images
 * @return Scanned dataset of images, same shape as X
 */
static np::Array<uint32_t> __scanGPU_3d__(const np::Array<uint32_t>& X) noexcept {
	np::Array<uint32_t> X_scan = np::empty<uint32_t>(X.shape);

	const size_t k = X.shape[0];
	const size_t height = X.shape[1];
	const size_t n = X.shape[2];
	// The kernel walks the scanned last axis (length n) with the grid's x dimension
	// and the rows (height) with its y dimension.
	const size_t n_block_x = static_cast<size_t>(std::ceil(static_cast<float64_t>(n) / static_cast<float64_t>(NB_THREADS_2D_X)));
	const size_t n_block_y = static_cast<size_t>(std::ceil(static_cast<float64_t>(height) / static_cast<float64_t>(NB_THREADS_2D_Y)));

	np::Array<uint32_t> d_X = copyToDevice<uint32_t>("X", X);
	// Host buffer copied only to obtain a device allocation of the right shape
	np::Array<uint32_t> inter = np::empty<uint32_t>({ k, height, n_block_x });
	np::Array<uint32_t> d_inter = copyToDevice<uint32_t>("inter", inter);

	const dim3 dimGrid(n_block_x, n_block_y, k);
	constexpr const dim3 dimBlock(NB_THREADS_2D_X, NB_THREADS_2D_Y);
	__kernel_scan_3d__<<<dimGrid, dimBlock>>>(n, height, d_inter, d_X);
	_print_cuda_error_("synchronize", cudaDeviceSynchronize());

	_print_cuda_error_("memcpy d_inter", cudaMemcpy(inter.data, d_inter.data, np::prod(inter.shape) * sizeof(uint32_t), cudaMemcpyDeviceToHost));
	cudaFree("free d_inter", d_inter);

	if(n_block_x >= NB_THREADS_2D_X){
		// Enough tiles to make a second GPU pass worthwhile
		np::Array<uint32_t> sums = __scanGPU_3d__(inter);

		np::Array<uint32_t> d_s = copyToDevice<uint32_t>("sums", sums);
		__add_3d__<<<dimGrid, dimBlock>>>(d_X, d_s, n, height);
		_print_cuda_error_("synchronize", cudaDeviceSynchronize());
		_print_cuda_error_("memcpy d_X", cudaMemcpy(X_scan.data, d_X.data, np::prod(X_scan.shape) * sizeof(uint32_t), cudaMemcpyDeviceToHost));
		cudaFree("free d_s", d_s);
	} else {
		np::Array<uint32_t> sums = __scanCPU_3d__(inter);
		_print_cuda_error_("memcpy d_X", cudaMemcpy(X_scan.data, d_X.data, np::prod(X_scan.shape) * sizeof(uint32_t), cudaMemcpyDeviceToHost));

		for(size_t p = 0; p < k; ++p)
			for(size_t h = 0; h < height; ++h){
				// X_scan rows are n long, sums rows are n_block_x long
				const size_t row = (p * height + h) * n;
				const size_t sums_row = (p * height + h) * n_block_x;
				// Tile 0 needs no offset: the exclusive scan of the totals starts at 0
				for(size_t i = 1; i < n_block_x; ++i){
					const uint32_t offset = sums[sums_row + i];
					for(size_t j = 0; j < NB_THREADS_2D_X; ++j){
						const size_t idx = i * NB_THREADS_2D_X + j;
						if(idx < n)
							X_scan[row + idx] += offset;
					}
				}
			}
	}

	cudaFree("free d_X", d_X);
	return X_scan;
}

/**
 * @brief GPU kernel of the function __transpose_3d__.
 *
 * d_X has shape (N, A, B) and d_Xt has shape (N, B, A). Expected launch (see __transpose_3d__):
 * - grid: (ceil(A / NB_THREADS_2D_X), ceil(B / NB_THREADS_2D_Y), N)
 * - blocks: NB_THREADS_2D_X x NB_THREADS_2D_Y
 *
 * Each block stages one tile in shared memory, then writes it back transposed.
 *
 * @param d_X Dataset of images on device
 * @param d_Xt Transposed dataset of images on device
 */
static __global__ void __transpose_kernel__(const np::Array<uint32_t> d_X, np::Array<uint32_t> d_Xt) {
	// The tile uses NB_THREADS_2D_Y as its row stride for both the load and the
	// transposed read, which only maps threads one-to-one for square blocks.
	static_assert(NB_THREADS_2D_X == NB_THREADS_2D_Y, "__transpose_kernel__ requires square thread blocks");
	__shared__ uint32_t temp[NB_THREADS_2D_X * NB_THREADS_2D_Y];

	// Load: thread (tx, ty) reads X[x][y] into tile slot (ty, tx)
	size_t x = blockIdx.x * blockDim.x + threadIdx.x;
	size_t y = blockIdx.y * blockDim.y + threadIdx.y;
	if(x < d_X.shape[1] && y < d_X.shape[2])
		temp[threadIdx.y * NB_THREADS_2D_Y + threadIdx.x] = d_X[blockIdx.z * d_X.shape[1] * d_X.shape[2] + x * d_X.shape[2] + y];

	__syncthreads();

	// Store: swap the block coordinates so that Xt[x][y] = X[y][x]
	x = blockIdx.y * blockDim.y + threadIdx.x;
	y = blockIdx.x * blockDim.x + threadIdx.y;
	if(x < d_X.shape[2] && y < d_X.shape[1])
		// Row stride of the output is the output's own last dimension (A), not B
		d_Xt[blockIdx.z * d_Xt.shape[1] * d_Xt.shape[2] + x * d_Xt.shape[2] + y] = temp[threadIdx.x * NB_THREADS_2D_Y + threadIdx.y];
}

/**
 * @brief Transpose every image of the given dataset on the GPU.
 *
 * Swaps the last two axes: an input of shape (N, A, B) yields (N, B, A).
 *
 * @param X Dataset of images
 * @return Transposed dataset of images
 */
static np::Array<uint32_t> __transpose_3d__(const np::Array<uint32_t>& X) noexcept {
	np::Array<uint32_t> Xt = np::empty<uint32_t>({ X.shape[0], X.shape[2], X.shape[1] });

	// Grid: x covers axis 1 of X, y covers axis 2, z walks the images
	const auto blocks_for = [](const size_t extent, const size_t per_block) {
		return static_cast<size_t>(std::ceil(static_cast<float64_t>(extent) / static_cast<float64_t>(per_block)));
	};
	const dim3 dimGrid(blocks_for(X.shape[1], NB_THREADS_2D_X), blocks_for(X.shape[2], NB_THREADS_2D_Y), X.shape[0]);
	constexpr const dim3 dimBlock(NB_THREADS_2D_X, NB_THREADS_2D_Y);

	np::Array<uint32_t> d_X = copyToDevice<uint32_t>("X", X);
	// Xt is copied only to obtain a device buffer of the transposed shape
	np::Array<uint32_t> d_Xt = copyToDevice<uint32_t>("Xt", Xt);

	__transpose_kernel__<<<dimGrid, dimBlock>>>(d_X, d_Xt);
	_print_cuda_error_("synchronize", cudaDeviceSynchronize());

	const size_t n_bytes = np::prod(Xt.shape) * sizeof(uint32_t);
	_print_cuda_error_("memcpy d_Xt", cudaMemcpy(Xt.data, d_Xt.data, n_bytes, cudaMemcpyDeviceToHost));

	cudaFree("X", d_X);
	cudaFree("Xt", d_Xt);
	return Xt;
}

/**
 * @brief Transform the input images in integrated images (GPU build).
 *
 * Each image is scanned along its last axis, transposed, scanned again, and
 * transposed back. The scans run on the CPU (__scanCPU_3d__) and the transposes
 * on the GPU (__transpose_3d__). Both scans are exclusive, so entry (r, c) holds the
 * sum of the pixels strictly above and to the left of (r, c).
 *
 * @param X Dataset of images
 * @return Dataset of integrated images
 */
np::Array<uint32_t> set_integral_image(const np::Array<uint8_t>& X) noexcept {
	const np::Array<uint32_t> row_scanned = __scanCPU_3d__(np::astype<uint32_t>(X));
	const np::Array<uint32_t> col_scanned = __scanCPU_3d__(__transpose_3d__(row_scanned));
	return __transpose_3d__(col_scanned);
}

/**
 * @brief GPU kernel of the function train_weak_clf.
 *
 * One thread per feature row i of d_X_feat. The thread walks the samples in the
 * order given by d_X_feat_argsort[i]. Before consuming each sample it evaluates the
 * error of splitting at that sample, min(neg_seen_w + total_pos - pos_seen_w,
 * pos_seen_w + total_neg - neg_seen_w). It keeps the threshold and polarity of the
 * lowest-error split, and writes them to d_classifiers[i] as (threshold, polarity).
 * Launched by train_weak_clf with a 1D grid of 3D blocks.
 *
 * @param d_classifiers Weak classifiers on device to train, two values per feature
 * @param d_y Labels of the features on device (1 marks a positive sample)
 * @param d_X_feat Feature images dataset on device
 * @param d_X_feat_argsort Sorted indexes of the integrated features on device
 * @param d_weights Weights of the features on device
 * @param total_pos Total of positive labels in the dataset
 * @param total_neg Total of negative labels in the dataset
 */
static __global__ void __train_weak_clf_kernel__(np::Array<float64_t> d_classifiers, const np::Array<uint8_t> d_y,
						const np::Array<int32_t> d_X_feat, const np::Array<uint16_t> d_X_feat_argsort,
						const np::Array<float64_t> d_weights, const float64_t total_pos, const float64_t total_neg) {
	// Flatten (block, thread.x, thread.y, thread.z) into a single feature index
	const size_t i = blockIdx.x * blockDim.x * blockDim.y * blockDim.z
		+ (threadIdx.x * blockDim.y + threadIdx.y) * blockDim.z
		+ threadIdx.z;
	if (i >= d_classifiers.shape[0])
		return;

	const size_t row = i * d_X_feat.shape[1];
	size_t pos_seen = 0, neg_seen = 0;
	float64_t pos_weights = 0.0, neg_weights = 0.0;
	float64_t min_error = np::inf, best_threshold = 0.0, best_polarity = 0.0;

	for (size_t j = 0; j < d_X_feat_argsort.shape[1]; ++j) {
		const uint16_t sample = d_X_feat_argsort[row + j];
		const float64_t error = np::min(neg_weights + total_pos - pos_weights, pos_weights + total_neg - neg_weights);
		if (error < min_error) {
			min_error = error;
			best_threshold = d_X_feat[row + sample];
			best_polarity = pos_seen > neg_seen ? 1.0 : -1.0;
		}

		const float64_t w = d_weights[sample];
		if (d_y[sample] == static_cast<uint8_t>(1)) {
			++pos_seen;
			pos_weights += w;
		} else {
			++neg_seen;
			neg_weights += w;
		}
	}

	d_classifiers[i * 2] = best_threshold;
	d_classifiers[i * 2 + 1] = best_polarity;
}

/**
 * @brief Train the weak classifiers on a given dataset (GPU version).
 *
 * The class weight totals are computed on the host. Then one GPU thread per feature
 * row of X_feat searches for that feature's best threshold and polarity.
 *
 * @param X_feat Feature images dataset (one row per feature)
 * @param X_feat_argsort Sorted indexes of the integrated features (indexed with X_feat's row length)
 * @param y Labels of the features (1 marks a positive sample)
 * @param weights Weights of the features
 * @return Trained weak classifiers, one (threshold, polarity) pair per feature
 */
np::Array<float64_t> train_weak_clf(const np::Array<int32_t>& X_feat, const np::Array<uint16_t>& X_feat_argsort, const np::Array<uint8_t>& y,
					const np::Array<float64_t>& weights) noexcept {
	float64_t total_pos = 0.0;
	float64_t total_neg = 0.0;
	for (size_t s = 0; s < y.shape[0]; ++s) {
		if (y[s] == static_cast<uint8_t>(1))
			total_pos += weights[s];
		else
			total_neg += weights[s];
	}

	np::Array<float64_t> classifiers = np::empty<float64_t>({ X_feat.shape[0], 2});

	np::Array<float64_t> d_classifiers = copyToDevice<float64_t>("classifiers", classifiers);
	np::Array<int32_t> d_X_feat = copyToDevice<int32_t>("X_feat", X_feat);
	np::Array<uint16_t> d_X_feat_argsort = copyToDevice<uint16_t>("X_feat_argsort", X_feat_argsort);
	np::Array<float64_t> d_weights = copyToDevice<float64_t>("weights", weights);
	np::Array<uint8_t> d_y = copyToDevice<uint8_t>("y", y);

	// One thread per feature, packed into 3D blocks
	const float64_t threads_per_block = static_cast<float64_t>(NB_THREADS_3D_X * NB_THREADS_3D_Y * NB_THREADS_3D_Z);
	const size_t n_blocks = static_cast<size_t>(std::ceil(static_cast<float64_t>(X_feat.shape[0]) / threads_per_block));
	constexpr const dim3 dimBlock(NB_THREADS_3D_X, NB_THREADS_3D_Y, NB_THREADS_3D_Z);
	__train_weak_clf_kernel__<<<n_blocks, dimBlock>>>(d_classifiers, d_y, d_X_feat, d_X_feat_argsort, d_weights, total_pos, total_neg);
	_print_cuda_error_("synchronize", cudaDeviceSynchronize());

	const size_t classifiers_bytes = np::prod(classifiers.shape) * sizeof(float64_t);
	_print_cuda_error_("memcpy classifiers", cudaMemcpy(classifiers.data, d_classifiers.data, classifiers_bytes, cudaMemcpyDeviceToHost));

	cudaFree("free d_classifiers", d_classifiers);
	cudaFree("free d_X_feat", d_X_feat);
	cudaFree("free d_X_feat_argsort", d_X_feat_argsort);
	cudaFree("free d_weights", d_weights);
	cudaFree("free d_y", d_y);

	return classifiers;
}

/**
 * @brief Compute a feature on a integrated image at a specific coordinate (GPU version).
 *
 * Evaluates the rectangle sum D + A - B - C from the four corners of the
 * rectangle (x, y, w, h) in one integrated image. Images are stored row-major along
 * the last axis (see __scanCPU_3d__), so moving one row down advances by
 * d_X_ii.shape[2] elements.
 *
 * NOTE(review): the result is narrowed to int16_t, so rectangle sums outside the int16_t
 * range wrap. This is kept as-is; confirm it matches the CPU implementation.
 *
 * @param d_X_ii Dataset of integrated images on device
 * @param j Flat offset of the image in the dataset
 * @param x X coordinate
 * @param y Y coordinate
 * @param w width of the feature
 * @param h height of the feature
 */
static inline __device__ int16_t __compute_feature__(const np::Array<uint32_t>& d_X_ii, const size_t& j, const int16_t& x, const int16_t& y, const int16_t& w, const int16_t& h) noexcept {
	// Row stride is the length of the last axis (identical to shape[1] only for square images)
	const size_t stride = d_X_ii.shape[2];
	const size_t _y = y * stride + x;
	const size_t _yh = _y + h * stride;
	return d_X_ii[j + _yh + w] + d_X_ii[j + _y] - d_X_ii[j + _yh] - d_X_ii[j + _y + w];
}

/**
 * @brief GPU kernel of the function apply_features.
 *
 * Thread (x, y) evaluates feature x on image y and writes the result to
 * d_X_feat[x * n_images + y]. Each feature is 16 bytes: four (x, y, w, h)
 * rectangles. The first two are added and the last two subtracted.
 *
 * @param d_X_feat Dataset of image features on device
 * @param d_feats Features on device to apply
 * @param d_X_ii Integrated image dataset on device
 */
static __global__ void __apply_feature_kernel__(int32_t* d_X_feat, const np::Array<uint8_t> d_feats, const np::Array<uint32_t> d_X_ii) {
	const size_t feat = blockIdx.x * blockDim.x + threadIdx.x;
	const size_t img = blockIdx.y * blockDim.y + threadIdx.y;
	if (feat >= d_feats.shape[0] || img >= d_X_ii.shape[0])
		return;

	const size_t out = feat * d_X_ii.shape[0] + img;
	const size_t f = feat * np::prod(d_feats.shape, 1);       // first coordinate of this feature
	const size_t img_off = img * np::prod(d_X_ii.shape, 1);   // first pixel of this image

	const int16_t p1 = __compute_feature__(d_X_ii, img_off, d_feats[f +  0], d_feats[f +  1], d_feats[f +  2], d_feats[f +  3]);
	const int16_t p2 = __compute_feature__(d_X_ii, img_off, d_feats[f +  4], d_feats[f +  5], d_feats[f +  6], d_feats[f +  7]);
	const int16_t n1 = __compute_feature__(d_X_ii, img_off, d_feats[f +  8], d_feats[f +  9], d_feats[f + 10], d_feats[f + 11]);
	const int16_t n2 = __compute_feature__(d_X_ii, img_off, d_feats[f + 12], d_feats[f + 13], d_feats[f + 14], d_feats[f + 15]);

	const int32_t positive = static_cast<int32_t>(p1 + p2);
	const int32_t negative = static_cast<int32_t>(n1 + n2);
	d_X_feat[out] = positive - negative;
}

/**
 * @brief Apply the features on a integrated image dataset (GPU version).
 *
 * Launches one GPU thread per (feature, image) pair.
 *
 * @param feats Features to apply
 * @param X_ii Integrated image dataset
 * @return Applied features, shape (number of features, number of images)
 */
np::Array<int32_t> apply_features(const np::Array<uint8_t>& feats, const np::Array<uint32_t>& X_ii) noexcept {
	const np::Array<int32_t> X_feat = np::empty<int32_t>({ feats.shape[0], X_ii.shape[0] });
	const size_t X_feat_bytes = np::prod(X_feat.shape) * sizeof(int32_t);
	int32_t* d_X_feat = nullptr;

	_print_cuda_error_("malloc d_X_feat", cudaMalloc(&d_X_feat, X_feat_bytes));
	np::Array<uint32_t> d_X_ii = copyToDevice<uint32_t>("X_ii", X_ii);
	np::Array<uint8_t> d_feats = copyToDevice<uint8_t>("feats", feats);

	// Grid x walks the features, grid y walks the images
	const size_t dimX = static_cast<size_t>(std::ceil(static_cast<float64_t>(feats.shape[0]) / static_cast<float64_t>(NB_THREADS_2D_X)));
	const size_t dimY = static_cast<size_t>(std::ceil(static_cast<float64_t>(X_ii.shape[0]) / static_cast<float64_t>(NB_THREADS_2D_Y)));
	const dim3 dimGrid(dimX, dimY);
	constexpr const dim3 dimBlock(NB_THREADS_2D_X, NB_THREADS_2D_Y);
	__apply_feature_kernel__<<<dimGrid, dimBlock>>>(d_X_feat, d_feats, d_X_ii);
	_print_cuda_error_("synchronize", cudaDeviceSynchronize());

	_print_cuda_error_("memcpy X_feat", cudaMemcpy(X_feat.data, d_X_feat, X_feat_bytes, cudaMemcpyDeviceToHost));

	_print_cuda_error_("free d_X_feat", cudaFree(d_X_feat));
	cudaFree("free d_feats", d_feats);
	cudaFree("free d_X_ii", d_X_ii);

	return X_feat;
}

/**
 * @brief Partition of the argsort algorithm (Lomuto scheme, last element as pivot).
 *
 * Reorders d_indices[low..high] so that every index whose key d_a[index] is less
 * than the pivot key comes first. The pivot index is then placed right after them.
 *
 * @tparam T Inner type of the array
 * @param d_a Array on device to sort
 * @param d_indices Array of indices on device to write to
 * @param low lower bound to sort
 * @param high higher bound to sort (inclusive, holds the pivot)
 * @return Final position of the pivot
 */
template<typename T>
__device__ inline static int32_t _as_partition_(const T* d_a, uint16_t* const d_indices, const size_t low, const size_t high) noexcept {
	// d_indices[high] is never touched inside the loop, so its key can be cached
	const T pivot = d_a[d_indices[high]];
	int32_t boundary = low;   // next slot for an index whose key is below the pivot
	for (int32_t j = low; j < high; ++j)
		if (d_a[d_indices[j]] < pivot)
			swap(&d_indices[boundary++], &d_indices[j]);
	swap(&d_indices[boundary], &d_indices[high]);
	return boundary;
}

/**
 * @brief Indirect in-place quicksort of d_indices[low..high] (inclusive) by key d_a[d_indices[.]].
 *
 * Iterative quicksort driven by a small explicit stack. After every partition the
 * larger non-trivial sub-range is deferred and the smaller one is processed
 * immediately. This keeps at most log2(high - low + 1) ranges pending, so a fixed
 * 64-entry stack covers any range length.
 *
 * @tparam T Inner type of the array
 * @param d_a Array on device to sort
 * @param d_indices Array of indices on device to write to
 * @param low lower bound to sort
 * @param high higher bound to sort (inclusive)
 */
template<typename T>
__device__ void argsort_kernel(const T* d_a, uint16_t* const d_indices, size_t low, size_t high) noexcept {
	// Empty, single-element or inverted ranges need no work
	if (low >= high)
		return;

	constexpr size_t STACK_CAPACITY = 64;
	size_t pending_begin[STACK_CAPACITY];
	size_t pending_end[STACK_CAPACITY];
	size_t n_pending = 0;

	// Half-open ranges [begin, end) keep empty sub-ranges from underflowing
	size_t begin = low, end = high + 1;
	for (;;) {
		while (end - begin > 1) {
			const size_t p = static_cast<size_t>(_as_partition_(d_a, d_indices, begin, end - 1));
			const size_t left_len = p - begin;     // elements in [begin, p)
			const size_t right_len = end - p - 1;  // elements in [p + 1, end)
			if (left_len < right_len) {
				// Defer the larger right part, keep going with the left part
				if (right_len > 1) {
					pending_begin[n_pending] = p + 1;
					pending_end[n_pending] = end;
					++n_pending;
				}
				end = p;
			} else {
				// Defer the larger left part, keep going with the right part
				if (left_len > 1) {
					pending_begin[n_pending] = begin;
					pending_end[n_pending] = p;
					++n_pending;
				}
				begin = p + 1;
			}
		}
		if (n_pending == 0)
			break;
		--n_pending;
		begin = pending_begin[n_pending];
		end = pending_end[n_pending];
	}
}

/**
 * @brief Cuda kernel where argsort is applied to every row of a given 2D array.
 *
 * One thread per row. The thread fills its row of d_indices with 0..n_cols-1 and then
 * sorts those indices by the row's values.
 *
 * @tparam T Inner type of the array
 * @param d_a 2D Array on device to sort
 * @param d_indices 2D Array of indices on device to write to (same layout as d_a)
 */
template<typename T>
__global__ void argsort_bounded(const np::Array<T> d_a, uint16_t* const d_indices){
	const size_t row = blockIdx.x * blockDim.x + threadIdx.x;
	if (row < d_a.shape[0]) {
		const size_t n_cols = d_a.shape[1];
		uint16_t* const row_indices = d_indices + row * n_cols;
		for (size_t c = 0; c < n_cols; ++c)
			row_indices[c] = c;
		argsort_kernel(&d_a[row * n_cols], row_indices, 0, n_cols - 1);
	}
}

/**
 * @brief Perform an indirect sort on each row of a given 2D array (GPU version).
 *
 * One GPU thread sorts one row of a. For every row, the result holds the column
 * indices that order that row ascending.
 *
 * @param a 2D Array to sort
 * @return 2D Array of indices that sort the array, same shape as a
 */
np::Array<uint16_t> argsort_2d(const np::Array<int32_t>& a) noexcept {
	const np::Array<uint16_t> indices = np::empty<uint16_t>(a.shape);
	uint16_t* d_indices = nullptr;
	const size_t indices_size = np::prod(indices.shape) * sizeof(uint16_t);

	np::Array<int32_t> d_a = copyToDevice<int32_t>("X_feat", a);
	_print_cuda_error_("malloc d_indices", cudaMalloc(&d_indices, indices_size));

	const float64_t n_rows = static_cast<float64_t>(a.shape[0]);
	const size_t dimGrid = static_cast<size_t>(std::ceil(n_rows / static_cast<float64_t>(NB_THREADS)));
	const dim3 dimBlock(NB_THREADS);
	argsort_bounded<<<dimGrid, dimBlock>>>(d_a, d_indices);
	_print_cuda_error_("synchronize", cudaDeviceSynchronize());

	_print_cuda_error_("memcpy d_indices", cudaMemcpy(indices.data, d_indices, indices_size, cudaMemcpyDeviceToHost));

	cudaFree("free d_a", d_a);
	_print_cuda_error_("free d_indices", cudaFree(d_indices));
	return indices;
}

#endif // GPU_BOOSTED