"""Convert a directory tree of binary PGM (P5) images into flat text data files.

Expected input layout (as read by ``__main__``)::

    <data_path>/{train,test}/{non-face,face}/<files>

For each set, ``X_<set>.bin`` (images) and ``y_<set>.bin`` (labels, 0 for the
'non-face' directory and 1 for the 'face' directory) are written into
``<data_path>``. Each output file contains the array shape on its first line
followed by the flattened values, both space-separated.
"""
from io import BufferedReader
from functools import partial
from sys import argv
from os import path, listdir
from typing import Final, Callable

import numpy as np
from numpy.typing import NDArray

try:
    from tqdm import tqdm
except ImportError:  # tqdm only draws progress bars; degrade to a pass-through wrapper
    def tqdm(iterable, *args, **kwargs):
        """Fallback used when tqdm is not installed: return *iterable* unchanged."""
        return iterable

# Induce determinism: seeds the global RNG used by np.random.permutation below
np.random.seed(196_863)

# Makes the 'leave' argument default to False
tqdm: Callable = partial(tqdm, leave = False)


def _read_header_line(pgm_file: BufferedReader) -> bytes:
    """Return the next header line of a PGM file, skipping ``#`` comment lines.

    Raises:
        ValueError: If the file ends before a non-comment line is found
    """
    while True:
        line = pgm_file.readline()
        if not line:
            raise ValueError('Unexpected end of file while reading PGM header')
        if not line.startswith(b'#'):
            return line


def read_pgm(pgm_file: BufferedReader) -> NDArray[np.uint8]:
    """Read the pixel data of a binary (P5) PGM file.

    The header is expected as newline-terminated lines: the ``P5`` magic number,
    ``width height``, then the maximum grey value; ``#`` comment lines may appear
    after the magic number. Only 8-bit images (maximum value 1..255) are supported.

    Args:
        pgm_file (BufferedReader): PGM file opened in binary mode

    Raises:
        ValueError: If the header is malformed, the maximum value is outside
            1..255, or fewer than width * height pixel bytes are available

    Returns:
        NDArray[np.uint8]: PGM data with shape (height, width)
    """
    # Explicit checks rather than `assert`: the header reads must still happen
    # (and validation still run) under `python -O`, which strips assertions.
    magic: Final[bytes] = pgm_file.readline()
    if magic.rstrip() != b'P5':
        raise ValueError(f'Incorrect file format: {magic}')

    dims: Final[list] = _read_header_line(pgm_file).split()
    if len(dims) != 2:
        raise ValueError(f'Incorrect dimensions: {dims}')
    width, height = (int(i) for i in dims)
    if width <= 0 or height <= 0:
        raise ValueError(f'Incorrect dimensions: {width}x{height}')

    depth: Final[int] = int(_read_header_line(pgm_file))
    if not 0 < depth < 256:
        raise ValueError(f'Incorrect depth: {depth}')

    # Read the whole raster in one call instead of one byte at a time
    n_pixels: Final[int] = height * width
    data: Final[bytes] = pgm_file.read(n_pixels)
    if len(data) != n_pixels:
        raise ValueError(f'Truncated pixel data: expected {n_pixels} bytes, got {len(data)}')
    # frombuffer gives a read-only view over `data`; copy to return a writable array
    return np.frombuffer(data, dtype = np.uint8).reshape((height, width)).copy()


def _load_set(data_path: str, set_name: str) -> tuple:
    """Read every PGM file of one set and return (images, labels), jointly shuffled.

    Files under 'non-face' get label 0 and files under 'face' get label 1.
    """
    X, y = [], []
    for y_i, label in enumerate(tqdm(['non-face', 'face'], desc = 'label')):
        label_dir = path.join(data_path, set_name, label)
        # listdir order is filesystem-dependent; sort so the seeded shuffle is reproducible
        for filename in tqdm(sorted(listdir(label_dir)), desc = 'Reading pgm file'):
            with open(path.join(label_dir, filename), 'rb') as pgm_file:
                X.append(read_pgm(pgm_file))
            y.append(y_i)
    X, y = np.asarray(X), np.asarray(y)
    # One permutation for both arrays keeps each image aligned with its label
    idx: NDArray[np.int64] = np.random.permutation(y.shape[0])
    return X[idx], y[idx]


def _write_bin(file_path: str, arr: NDArray, desc: str) -> None:
    """Write *arr* as text: its shape on the first line, then its flattened values.

    Both lines are space-separated; the values line has no trailing newline.
    An empty array yields just the shape line.
    """
    with open(file_path, 'w') as out:
        out.write(' '.join(map(str, arr.shape)) + '\n')
        out.write(' '.join(map(str, tqdm(arr.ravel().tolist(), desc = desc))))


def __main__(data_path: str) -> None:
    """Read the data of every PGM file and output it in data files

    Args:
        data_path (str): Path of the PGM files; the output files are written there too
    """
    for set_name in tqdm(['train', 'test'], desc = 'set name'):
        X, y = _load_set(data_path, set_name)
        for org, arr in tqdm(zip('Xy', [X, y]), desc = f'Writing {set_name}'):
            _write_bin(path.join(data_path, f'{org}_{set_name}.bin'), arr, f'Writing {org}')


if __name__ == '__main__':
    if len(argv) == 2:
        __main__(argv[1])
    else:
        print(f'Usage: python {path.basename(__file__)} ./data_location')