ViolaJones/downloader/convert_dataset.py

63 lines
2.2 KiB
Python

from io import BufferedReader
from tqdm import tqdm
from functools import partial
from sys import argv
import numpy as np
from numpy.typing import NDArray
from typing import Final, Callable
from os import path, listdir
# Rebind tqdm so that every progress bar created below gets leave=False
# unless the call site overrides it
tqdm: Callable = partial(tqdm, leave=False)

# Seed NumPy's global random generator so the np.random calls made in this
# script draw the same sequence on every run
np.random.seed(196863)
def read_pgm(pgm_file: BufferedReader) -> NDArray[np.uint8]:
    """Read the data of a PGM file

    Only the binary ('P5') variant with a three-line header is handled: the
    magic number, then 'width height', then the maximum grey value, each on its
    own line, followed by one byte per pixel.

    Args:
        pgm_file (BufferedReader): PGM File, opened in binary mode

    Returns:
        NDArray[np.uint8]: PGM data, as an array of shape (height, width)

    Raises:
        ValueError: If the magic number, the dimensions or the depth are
            invalid, or if the file holds fewer pixels than the header announces
    """
    # Validation uses explicit raises instead of asserts: asserts are stripped
    # under 'python -O', which would also skip the reads embedded in them and
    # leave the stream positioned on the wrong header line
    if (magic := pgm_file.readline()) != b'P5\n':
        raise ValueError(f'Incorrect file format: {magic}')

    (width, height) = (int(i) for i in pgm_file.readline().split())
    if width <= 0 or height <= 0:
        raise ValueError(f'Incorrect dimensions: {width}x{height}')

    if (depth := int(pgm_file.readline())) >= 256:
        raise ValueError(f'Incorrect depth: {depth}')

    # Read every pixel in one call rather than one byte at a time
    n_pixels: Final[int] = height * width
    raw: Final[bytes] = pgm_file.read(n_pixels)
    if len(raw) != n_pixels:
        raise ValueError(f'Truncated pixel data: expected {n_pixels} bytes, got {len(raw)}')

    # np.frombuffer gives a read-only view over the bytes object; copy it so
    # the caller receives an independent, writable array
    return np.frombuffer(raw, dtype = np.uint8).reshape((height, width)).copy()
def _write_array(file_path: str, array: NDArray, desc: str) -> None:
    """Write an array to a text file as a shape line followed by its values

    The first line holds the dimensions separated by single spaces. The second
    holds the flattened values separated by single spaces, with no trailing
    separator or newline. An empty array produces the shape line only.

    Args:
        file_path (str): Path of the file to create or overwrite
        array (NDArray): Array to serialise
        desc (str): Label of the progress bar
    """
    # Number of values formatted and written per call, to avoid one write per value
    chunk: Final[int] = 65_536
    flat: NDArray = array.ravel()
    with open(file_path, 'w') as out:
        out.write(' '.join(str(dim) for dim in array.shape) + '\n')
        for start in tqdm(range(0, flat.shape[0], chunk), desc = desc):
            # Separate consecutive batches exactly like consecutive values
            if start > 0:
                out.write(' ')
            out.write(' '.join(str(v) for v in flat[start:start + chunk].tolist()))


def __main__(data_path: str) -> None:
    """Read the data of every PGM file and output it in data files

    For each of the 'train' and 'test' sets, the images found under
    '<data_path>/<set>/non-face' (label 0) and '<data_path>/<set>/face'
    (label 1) are loaded, shuffled together with their labels, and written to
    '<data_path>/X_<set>.bin' (images) and '<data_path>/y_<set>.bin' (labels).

    Args:
        data_path (str): Path of the PGM files
    """
    for set_name in tqdm(['train', 'test'], desc = 'set name'):
        X, y = [], []
        for y_i, label in enumerate(tqdm(['non-face', 'face'], desc = 'label')):
            label_dir: str = path.join(data_path, set_name, label)
            # listdir returns entries in arbitrary order; sorting makes the
            # seeded shuffle below give the same result on every machine
            for filename in tqdm(sorted(listdir(label_dir)), desc = 'Reading pgm file'):
                with open(path.join(label_dir, filename), 'rb') as face:
                    X.append(read_pgm(face))
                    y.append(y_i)

        X, y = np.asarray(X), np.asarray(y)
        # Apply one shared permutation so images and labels stay aligned
        idx: NDArray[np.int64] = np.random.permutation(y.shape[0])
        X, y = X[idx], y[idx]

        for org, s in tqdm(zip('Xy', [X, y]), desc = f'Writing {set_name}'):
            _write_array(path.join(data_path, f'{org}_{set_name}.bin'), s, f'Writing {org}')
if __name__ == '__main__':
    # Exactly one command-line argument is expected: the dataset location.
    # Anything else prints the usage line instead of running the conversion.
    if len(argv) == 2:
        __main__(argv[1])
    else:
        print(f'Usage: python {path.basename(__file__)} ./data_location')