downloader/convert_dataset.py : Added better typing and formatting

This commit is contained in:
saundersp 2024-11-08 01:23:38 +01:00
parent e9df962d7a
commit 2051ae8cdc

View File

@ -3,29 +3,31 @@ from tqdm import tqdm
from functools import partial from functools import partial
from sys import argv from sys import argv
import numpy as np import numpy as np
from numpy.typing import NDArray
from typing import Final, Callable
from os import path, listdir from os import path, listdir
# Induce determinism # Induce determinism
np.random.seed(133742) np.random.seed(196_863)
# Makes the "leave" argument default to False # Makes the 'leave' argument default to False
tqdm = partial(tqdm, leave = False) tqdm: Callable = partial(tqdm, leave = False)
def read_pgm(pgm_file: BufferedReader) -> np.ndarray: def read_pgm(pgm_file: BufferedReader) -> NDArray[np.uint8]:
"""Read the data of a PGM file """Read the data of a PGM file
Args: Args:
pgm_file (BufferedReader): PGM File pgm_file (BufferedReader): PGM File
Returns: Returns:
np.ndarray: PGM data NDArray[np.uint8]: PGM data
""" """
assert (f := pgm_file.readline()) == b'P5\n', f"Incorrect file format: {f}" assert (f := pgm_file.readline()) == b'P5\n', f'Incorrect file format: {f}'
(width, height) = [int(i) for i in pgm_file.readline().split()] (width, height) = (int(i) for i in pgm_file.readline().split())
assert width > 0 and height > 0, f"Incorrect dimensions: {width}x{height}" assert width > 0 and height > 0, f'Incorrect dimensions: {width}x{height}'
assert (depth := int(pgm_file.readline())) < 256, f"Incorrect depth: {depth}" assert (depth := int(pgm_file.readline())) < 256, f'Incorrect depth: {depth}'
buff = np.empty(height * width, dtype = np.uint8) buff: Final[NDArray[np.uint8]] = np.empty(height * width, dtype = np.uint8)
for i in range(buff.shape[0]): for i in range(buff.shape[0]):
buff[i] = ord(pgm_file.read(1)) buff[i] = ord(pgm_file.read(1))
return buff.reshape((height, width)) return buff.reshape((height, width))
@ -36,25 +38,25 @@ def __main__(data_path: str) -> None:
Args: Args:
data_path (str): Path of the PGM files data_path (str): Path of the PGM files
""" """
for set_name in tqdm(["train", "test"], desc = "set name"): for set_name in tqdm(['train', 'test'], desc = 'set name'):
X, y = [], [] X, y = [], []
for y_i, label in enumerate(tqdm(["non-face", "face"], desc = "label")): for y_i, label in enumerate(tqdm(['non-face', 'face'], desc = 'label')):
for filename in tqdm(listdir(f"{data_path}/{set_name}/{label}"), desc = "Reading pgm file"): for filename in tqdm(listdir(f'{data_path}/{set_name}/{label}'), desc = 'Reading pgm file'):
with open(f"{data_path}/{set_name}/{label}/{filename}", "rb") as face: with open(f'{data_path}/{set_name}/{label}/{filename}', 'rb') as face:
X.append(read_pgm(face)) X.append(read_pgm(face))
y.append(y_i) y.append(y_i)
X, y = np.asarray(X), np.asarray(y) X, y = np.asarray(X), np.asarray(y)
idx = np.random.permutation(y.shape[0]) idx: NDArray[np.int64] = np.random.permutation(y.shape[0])
X, y = X[idx], y[idx] X, y = X[idx], y[idx]
for org, s in tqdm(zip("Xy", [X, y]), desc = f"Writing {set_name}"): for org, s in tqdm(zip('Xy', [X, y]), desc = f'Writing {set_name}'):
with open(f"{data_path}/{org}_{set_name}.bin", "w") as out: with open(f'{data_path}/{org}_{set_name}.bin', 'w') as out:
out.write(f'{str(s.shape)[1:-1].replace(",", "")}\n') out.write(f'{str(s.shape)[1:-1].replace(',', '')}\n')
raw = s.ravel() raw: NDArray = s.ravel()
for s_i in tqdm(raw[:-1], desc = f"Writing {org}"): for s_i in tqdm(raw[:-1], desc = f'Writing {org}'):
out.write(f"{s_i} ") out.write(f'{s_i} ')
out.write(str(raw[-1])) out.write(str(raw[-1]))
if __name__ == "__main__": if __name__ == '__main__':
__main__(argv[1]) if len(argv) == 2 else print(f"Usage: python {__file__[__file__.rfind(path.sep) + 1:]} ./data_location") __main__(argv[1]) if len(argv) == 2 else print(f'Usage: python {__file__[__file__.rfind(path.sep) + 1:]} ./data_location')