downloader/convert_dataset.py : Added better typing and formatting
This commit is contained in:
parent
e9df962d7a
commit
2051ae8cdc
@ -3,29 +3,31 @@ from tqdm import tqdm
|
|||||||
from functools import partial
|
from functools import partial
|
||||||
from sys import argv
|
from sys import argv
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from numpy.typing import NDArray
|
||||||
|
from typing import Final, Callable
|
||||||
from os import path, listdir
|
from os import path, listdir
|
||||||
|
|
||||||
# Induce determinism
|
# Induce determinism
|
||||||
np.random.seed(133742)
|
np.random.seed(196_863)
|
||||||
|
|
||||||
# Makes the "leave" argument default to False
|
# Makes the 'leave' argument default to False
|
||||||
tqdm = partial(tqdm, leave = False)
|
tqdm: Callable = partial(tqdm, leave = False)
|
||||||
|
|
||||||
def read_pgm(pgm_file: BufferedReader) -> np.ndarray:
|
def read_pgm(pgm_file: BufferedReader) -> NDArray[np.uint8]:
|
||||||
"""Read the data of a PGM file
|
"""Read the data of a PGM file
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
pgm_file (BufferedReader): PGM File
|
pgm_file (BufferedReader): PGM File
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
np.ndarray: PGM data
|
NDArray[np.uint8]: PGM data
|
||||||
"""
|
"""
|
||||||
assert (f := pgm_file.readline()) == b'P5\n', f"Incorrect file format: {f}"
|
assert (f := pgm_file.readline()) == b'P5\n', f'Incorrect file format: {f}'
|
||||||
(width, height) = [int(i) for i in pgm_file.readline().split()]
|
(width, height) = (int(i) for i in pgm_file.readline().split())
|
||||||
assert width > 0 and height > 0, f"Incorrect dimensions: {width}x{height}"
|
assert width > 0 and height > 0, f'Incorrect dimensions: {width}x{height}'
|
||||||
assert (depth := int(pgm_file.readline())) < 256, f"Incorrect depth: {depth}"
|
assert (depth := int(pgm_file.readline())) < 256, f'Incorrect depth: {depth}'
|
||||||
|
|
||||||
buff = np.empty(height * width, dtype = np.uint8)
|
buff: Final[NDArray[np.uint8]] = np.empty(height * width, dtype = np.uint8)
|
||||||
for i in range(buff.shape[0]):
|
for i in range(buff.shape[0]):
|
||||||
buff[i] = ord(pgm_file.read(1))
|
buff[i] = ord(pgm_file.read(1))
|
||||||
return buff.reshape((height, width))
|
return buff.reshape((height, width))
|
||||||
@ -36,25 +38,25 @@ def __main__(data_path: str) -> None:
|
|||||||
Args:
|
Args:
|
||||||
data_path (str): Path of the PGM files
|
data_path (str): Path of the PGM files
|
||||||
"""
|
"""
|
||||||
for set_name in tqdm(["train", "test"], desc = "set name"):
|
for set_name in tqdm(['train', 'test'], desc = 'set name'):
|
||||||
X, y = [], []
|
X, y = [], []
|
||||||
for y_i, label in enumerate(tqdm(["non-face", "face"], desc = "label")):
|
for y_i, label in enumerate(tqdm(['non-face', 'face'], desc = 'label')):
|
||||||
for filename in tqdm(listdir(f"{data_path}/{set_name}/{label}"), desc = "Reading pgm file"):
|
for filename in tqdm(listdir(f'{data_path}/{set_name}/{label}'), desc = 'Reading pgm file'):
|
||||||
with open(f"{data_path}/{set_name}/{label}/{filename}", "rb") as face:
|
with open(f'{data_path}/{set_name}/{label}/{filename}', 'rb') as face:
|
||||||
X.append(read_pgm(face))
|
X.append(read_pgm(face))
|
||||||
y.append(y_i)
|
y.append(y_i)
|
||||||
|
|
||||||
X, y = np.asarray(X), np.asarray(y)
|
X, y = np.asarray(X), np.asarray(y)
|
||||||
idx = np.random.permutation(y.shape[0])
|
idx: NDArray[np.int64] = np.random.permutation(y.shape[0])
|
||||||
X, y = X[idx], y[idx]
|
X, y = X[idx], y[idx]
|
||||||
|
|
||||||
for org, s in tqdm(zip("Xy", [X, y]), desc = f"Writing {set_name}"):
|
for org, s in tqdm(zip('Xy', [X, y]), desc = f'Writing {set_name}'):
|
||||||
with open(f"{data_path}/{org}_{set_name}.bin", "w") as out:
|
with open(f'{data_path}/{org}_{set_name}.bin', 'w') as out:
|
||||||
out.write(f'{str(s.shape)[1:-1].replace(",", "")}\n')
|
out.write(f'{str(s.shape)[1:-1].replace(',', '')}\n')
|
||||||
raw = s.ravel()
|
raw: NDArray = s.ravel()
|
||||||
for s_i in tqdm(raw[:-1], desc = f"Writing {org}"):
|
for s_i in tqdm(raw[:-1], desc = f'Writing {org}'):
|
||||||
out.write(f"{s_i} ")
|
out.write(f'{s_i} ')
|
||||||
out.write(str(raw[-1]))
|
out.write(str(raw[-1]))
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == '__main__':
|
||||||
__main__(argv[1]) if len(argv) == 2 else print(f"Usage: python {__file__[__file__.rfind(path.sep) + 1:]} ./data_location")
|
__main__(argv[1]) if len(argv) == 2 else print(f'Usage: python {__file__[__file__.rfind(path.sep) + 1:]} ./data_location')
|
||||||
|
Loading…
x
Reference in New Issue
Block a user