#include "data.hpp"
//#include "toolbox.hpp"
//#include <cstring>

#include <memory>
#include <stdexcept>
#include <string>

/**
 * Print a shape in Python tuple notation followed by a newline, e.g. "(3, 24, 24)" or "(5,)".
 *
 * @param shape Shape to print
 * @return Number of characters written (sum of the printf return values)
 */
int print(const np::Shape& shape) noexcept {
	int num_written = printf("(");
	if (shape.length == 0) {
		// Nothing to index: an empty shape is printed as "()".
		num_written += printf(")\n");
	}
	else if (shape.length == 1) {
		// Single-element tuples keep the trailing comma, as in Python.
		num_written += printf("%zu,)\n", shape[0]);
	}
	else {
		const size_t last = shape.length - 1;
		for (size_t i = 0; i < last; ++i)
			num_written += printf("%zu, ", shape[i]);
		num_written += printf("%zu)\n", shape[last]);
	}
	return num_written;
}

/**
 * Print a 1-D or 2-D array, formatting every element with the printf conversion `format`.
 *
 * 1-D arrays print as "[a b c]\n". Any other rank is printed as shape[0] rows of shape[1] columns
 * ("[ [a b]\n [c d]\n]\n"); dimensions after the second are not taken into account.
 *
 * @param array  Array to print
 * @param format printf conversion for a single element (e.g. "%f"); must match T after the
 *               default argument promotions
 * @return Number of characters written (sum of the printf return values)
 */
template<typename T>
int print(const np::Array<T>& array, const char* format) noexcept {
	// Per-element formats with the separator / closing bracket appended. snprintf truncates an
	// over-long `format` instead of overflowing the fixed-size stack buffers.
	char format_space[BUFFER_SIZE] = { 0 };
	snprintf(format_space, sizeof(format_space), "%s ", format);
	char format_close[BUFFER_SIZE] = { 0 };
	snprintf(format_close, sizeof(format_close), "%s]\n", format);
	int num_written = 0;

	if (array.shape.length == 1) {
		const size_t length = array.shape[0];
		num_written += printf("[");
		if (length == 0) {
			// Empty array: `length - 1` would wrap around and index far out of bounds.
			return num_written + printf("]\n");
		}
		const size_t max = length - 1;
		for (size_t i = 0; i < max; ++i)
			num_written += printf(format_space, array[i]);
		num_written += printf(format_close, array[max]);
	}
	else {
		const size_t rows = array.shape[0];
		const size_t cols = array.shape[1];
		num_written += printf("[");
		for (size_t i = 0; i < rows; ++i) {
			num_written += printf(" [");
			if (cols == 0) {
				// Empty row: `cols - 1` would wrap around, so just close the bracket.
				num_written += printf("]\n");
				continue;
			}
			const size_t row_start = i * cols;
			for (size_t j = 0; j < cols - 1; ++j)
				num_written += printf(format_space, array[row_start + j]);
			num_written += printf(format_close, array[row_start + cols - 1]);
		}
		num_written += printf("]\n");
	}

	return num_written;
}

/**
 * Print a 1-D or 2-D uint8_t array, formatting each element with "%hu".
 *
 * @param array Array to print
 * @return Number of characters written
 */
int print(const np::Array<uint8_t>& array) noexcept {
	constexpr const char* element_format = "%hu";
	return print(array, element_format);
}

/**
 * Print a 1-D or 2-D float64_t array, formatting each element with "%f".
 *
 * @param array Array to print
 * @return Number of characters written
 */
int print(const np::Array<float64_t>& array) noexcept {
	constexpr const char* element_format = "%f";
	return print(array, element_format);
}

/**
 * Print the values of sample `slice.x` on one line as "[ a  b  c]\n" (each value padded to width 2).
 *
 * The sample size is np::prod(array.shape, 1), presumably the product of every dimension after the
 * first (confirm against np::prod); the sample starts at flat index slice.x * that size.
 *
 * @param array Array holding the samples
 * @param slice Only slice.x (the sample index) is used
 * @return Number of characters written
 */
int print_feat(const np::Array<uint8_t>& array, const np::Slice& slice) noexcept {
	int written = printf("[");
	const size_t sample_size = np::prod(array.shape, 1);
	const size_t first = slice.x * sample_size;
	const size_t last = first + sample_size - 1;
	for (size_t idx = first; idx < last; ++idx)
		written += printf("%2hu ", array[idx]);
	written += printf("%2hu]\n", array[last]);
	return written;
}

/**
 * Print a slice of a uint8_t array.
 *
 * 1-D arrays: elements [slice.x, slice.y) on one line as "[a b c]\n" ("[]\n" if the slice is empty).
 * Otherwise (rank >= 3 assumed): the shape[1] x shape[2] plane starting at flat offset
 * (slice.x, slice.y, slice.z), one row per line, each value padded to width 3.
 *
 * @param array Array to print
 * @param slice Range (1-D) or starting coordinates (3-D)
 * @return Number of characters written
 */
int print(const np::Array<uint8_t>& array, const np::Slice& slice) noexcept {
	int num_written = 0;
	if (array.shape.length == 1) {
		num_written += printf("[");
		// Empty slice: nothing to print, and `slice.y - 1` would point before the slice
		// (or wrap around when slice.y == 0).
		if (slice.y <= slice.x)
			return num_written + printf("]\n");
		const size_t max = slice.y - 1;
		for (size_t i = slice.x; i < max; ++i)
			num_written += printf("%hu ", array[i]);
		num_written += printf("%hu]\n", array[max]);
	}
	else {
		const size_t rows = array.shape[1];
		const size_t cols = array.shape[2];
		const size_t k = slice.x * rows * cols + slice.y * cols + slice.z;
		num_written += printf("[");
		for (size_t i = 0; i < rows; ++i) {
			num_written += printf(" [");
			// Consecutive rows of the plane are `cols` (= shape[2]) elements apart.
			for (size_t j = 0; j < cols; ++j)
				num_written += printf("%3hu ", array[k + i * cols + j]);
			num_written += printf("]\n");
		}
		num_written += printf("]\n");
	}
	return num_written;
}

/**
 * Print a slice of a uint32_t array.
 *
 * 1-D arrays: elements [slice.x, slice.y) on one line as "[a b c]\n" ("[]\n" if the slice is empty).
 * Otherwise (rank >= 3 assumed): the shape[1] x shape[2] plane starting at flat offset
 * (slice.x, slice.y, slice.z), one row per line, each value padded to width 5.
 *
 * @param array Array to print
 * @param slice Range (1-D) or starting coordinates (3-D)
 * @return Number of characters written
 */
int print(const np::Array<uint32_t>& array, const np::Slice& slice) noexcept {
	int num_written = 0;
	if (array.shape.length == 1) {
		num_written += printf("[");
		// Empty slice: nothing to print, and `slice.y - 1` would point before the slice
		// (or wrap around when slice.y == 0).
		if (slice.y <= slice.x)
			return num_written + printf("]\n");
		const size_t max = slice.y - 1;
		// "%u" is the unsigned conversion; "%iu" would print a signed value followed by a literal 'u'.
		for (size_t i = slice.x; i < max; ++i)
			num_written += printf("%u ", array[i]);
		num_written += printf("%u]\n", array[max]);
	}
	else {
		const size_t rows = array.shape[1];
		const size_t cols = array.shape[2];
		const size_t k = slice.x * rows * cols + slice.y * cols + slice.z;
		num_written += printf("[");
		for (size_t i = 0; i < rows; ++i) {
			num_written += printf(" [");
			// Consecutive rows of the plane are `cols` (= shape[2]) elements apart.
			for (size_t j = 0; j < cols; ++j)
				num_written += printf("%5u ", array[k + i * cols + j]);
			num_written += printf("]\n");
		}
		// NOTE(review): print("]") resolves to an overload not defined in this file, while the
		// uint8_t variant closes with printf("]\n") — confirm both produce the same output.
		num_written += print("]");
	}
	return num_written;
}

/**
 * Print (slice.y - slice.x) consecutive int32_t values starting at flat index
 * slice.x * shape[1] (the start of row slice.x), each padded to width 5, as "[v0 v1 ...".
 * The closing bracket goes through print("]"), whose definition is not in this file.
 *
 * @param array Array to print
 * @param slice slice.x selects the starting row, (slice.y - slice.x) is the element count
 * @return Number of characters written
 */
int print(const np::Array<int32_t>& array, const np::Slice& slice) noexcept {
	int written = printf("[");
	const size_t first = slice.x * array.shape[1];
	const size_t stop = first + (slice.y - slice.x);
	for (size_t idx = first; idx < stop; ++idx)
		written += printf("%5i ", array[idx]);
	written += print("]");
	return written;
}

/**
 * Print (slice.y - slice.x) consecutive uint16_t values starting at flat index
 * slice.x * shape[1] (the start of row slice.x), each padded to width 5, as "[v0 v1 ...".
 * The closing bracket goes through print("]"), whose definition is not in this file.
 *
 * @param array Array to print
 * @param slice slice.x selects the starting row, (slice.y - slice.x) is the element count
 * @return Number of characters written
 */
int print(const np::Array<uint16_t>& array, const np::Slice& slice) noexcept {
	int written = printf("[");
	const size_t first = slice.x * array.shape[1];
	const size_t stop = first + (slice.y - slice.x);
	for (size_t idx = first; idx < stop; ++idx)
		written += printf("%5hu ", array[idx]);
	written += print("]");
	return written;
}

/**
 * Load one dataset file: a text header line with up to three dimensions, followed by the values
 * as decimal integers separated by single spaces or newlines.
 *
 * The header is parsed as "%zu %zu %zu"; when the second dimension is 0 (or missing) the array is
 * 1-D of length dims[0], otherwise it is 3-D. Every value is read with atoi and narrowed to
 * uint8_t. Data beyond the announced size is ignored.
 *
 * @param set_name Path of the file to load
 * @return The filled array
 * @throws std::runtime_error if the file cannot be opened or read, the header is malformed, a value
 *         is longer than STRING_INT_SIZE - 1 characters, or fewer values than announced are found
 */
static inline np::Array<uint8_t> load_set(const char* set_name) {
	// Owning handle: fclose runs on every exit path, including the exceptions thrown below.
	const std::unique_ptr<FILE, decltype(&fclose)> file(fopen(set_name, "rb"), &fclose);
	if (!file) {
		print_error_file(set_name);
		throw std::runtime_error(std::string("Can't open ") + set_name);
	}

	char meta[BUFFER_SIZE];
	if (!fgets(meta, BUFFER_SIZE, file.get())) {
		print_error_file(set_name);
		throw std::runtime_error(std::string("Can't read the header of ") + set_name);
	}
	// sscanf returns EOF on input failure, otherwise the number of converted fields; at least the
	// first dimension is required (unconverted ones stay 0).
	size_t parsed[3] = { 0, 0, 0 };
	if (sscanf(meta, "%zu %zu %zu", &parsed[0], &parsed[1], &parsed[2]) < 1) {
		fprintf(stderr, "Invalid header in %s\n", set_name);
		throw std::runtime_error(std::string("Invalid header in ") + set_name);
	}
	// Heap copy handed to np::Shape (ownership presumably transfers to it — confirm in data.hpp).
	// Allocated only after the header is validated so the error paths above cannot leak it.
	size_t* dims = new size_t[3]{ parsed[0], parsed[1], parsed[2] };
	np::Shape shape = { static_cast<size_t>(dims[1] == 0 ? 1 : 3), dims };
	np::Array<uint8_t> a = np::empty<uint8_t>(std::move(shape));

	const size_t size = np::prod(a.shape);
	size_t i = 0, j = 0;
	char buff[STRING_INT_SIZE] = { 0 };
	int c = 0;
	// Testing `i < size` first stops reading once the array is full, so nothing is stored past its end.
	while (i < size && (c = fgetc(file.get())) != EOF) {
		if (c == ' ' || c == '\n') {
			buff[j] = '\0';
			a[i++] = static_cast<uint8_t>(atoi(buff));
			j = 0;
		}
		else if (j + 1 < sizeof(buff))
			buff[j++] = static_cast<char>(c);
		else {
			// Keep one byte for the terminator: an over-long token would overflow `buff`.
			fprintf(stderr, "Value too long in %s at index %zu\n", set_name, i);
			throw std::runtime_error(std::string("Value too long in ") + set_name);
		}
	}
	// The last value need not be followed by a separator: store it if one is pending and there is room.
	if (j > 0 && i < size) {
		buff[j] = '\0';
		a[i++] = static_cast<uint8_t>(atoi(buff));
	}
	if (i != size) {
		fprintf(stderr, "Missing loaded data %zu/%zu\n", i, size);
		throw std::runtime_error(std::string("Missing loaded data in ") + set_name);
	}

	return a;
}

/**
 * Load the four dataset files from DATA_DIR.
 *
 * Elements of a braced initializer list are evaluated left to right, so the files are read (and any
 * load failure reported) in this order: X_train, y_train, X_test, y_test.
 *
 * @return { X_train, y_train, X_test, y_test }
 */
std::array<np::Array<uint8_t>, 4> load_datasets() {
	return {
		load_set(DATA_DIR "/X_train.bin"),
		load_set(DATA_DIR "/y_train.bin"),
		load_set(DATA_DIR "/X_test.bin"),
		load_set(DATA_DIR "/y_test.bin"),
	};
}

/**
 * Report on stderr that `file_dir` could not be opened, including the current errno value and its
 * strerror description.
 *
 * @param file_dir Path that failed to open (printed verbatim)
 */
void print_error_file(const char* file_dir) noexcept {
	// Capture errno once: strerror() may itself modify errno, so the printed code and message must
	// both come from this saved value. The strerror buffer belongs to the C library; never free it.
	const int error_code = errno;
	fprintf(stderr, "Can't open %s, error code = %d : %s\n", file_dir, error_code, strerror(error_code));
}

//size_t np::prod(const np::Shape& shape, const size_t& offset) noexcept {
//	size_t result = shape[offset];
//	for(size_t i = 1 + offset; i < shape.length; ++i)
//		result *= shape[i];
//	return result;
//}