Better handling of printing results board
This commit is contained in:
@ -1,44 +1,47 @@
|
||||
from typing import Final
|
||||
import numpy as np
|
||||
|
||||
DATA_DIR = "../data"
|
||||
OUT_DIR = "./out"
|
||||
MODEL_DIR = "./models"
|
||||
DATA_DIR: Final = '../data'
|
||||
OUT_DIR: Final = './out'
|
||||
MODEL_DIR: Final = './models'
|
||||
|
||||
NB_THREADS = 1024
|
||||
NB_THREADS_2D = (32, 32)
|
||||
NB_THREADS_3D = (16, 16, 4)
|
||||
M = int(np.log2(NB_THREADS_2D[1]))
|
||||
NB_THREADS: Final = 1024
|
||||
NB_THREADS_2D: Final = (32, 32)
|
||||
NB_THREADS_3D: Final = (16, 16, 4)
|
||||
M: Final = int(np.log2(NB_THREADS_2D[1]))
|
||||
|
||||
# Save state to avoid recalculation on restart
|
||||
SAVE_STATE = True
|
||||
SAVE_STATE: Final = True
|
||||
# Redo the state even if it's already saved
|
||||
FORCE_REDO = False
|
||||
FORCE_REDO: Final = False
|
||||
# Use NJIT to greatly accelerate runtime
|
||||
COMPILE_WITH_C = True
|
||||
COMPILE_WITH_C: Final = True
|
||||
# Use GPU to greatly accelerate runtime (as priority over NJIT)
|
||||
GPU_BOOSTED = True
|
||||
GPU_BOOSTED: Final = True
|
||||
# Depending on what you set, the output label will be as follow :
|
||||
# | COMPILE_WITH_C | GPU_BOOSTED | LABEL |
|
||||
# |----------------|-------------|-------|
|
||||
# | True | True | GPU |
|
||||
# | True | False | CPU |
|
||||
# | False | True | PGPU |
|
||||
# | False | False | PY |
|
||||
# ┌────────────────┬─────────────┬───────┐
|
||||
# │ COMPILE_WITH_C │ GPU_BOOSTED │ LABEL │
|
||||
# ├────────────────┼─────────────┼───────┤
|
||||
# │ True │ True │ GPU │
|
||||
# │ True │ False │ CPU │
|
||||
# │ False │ True │ PGPU │
|
||||
# │ False │ False │ PY │
|
||||
# └────────────────┴─────────────┴───────┘
|
||||
|
||||
# Number of weak classifiers
|
||||
# TS = [1]
|
||||
# TS = [1, 5, 10]
|
||||
TS = [1, 5, 10, 25, 50]
|
||||
# TS = [1, 5, 10, 25, 50, 100, 200]
|
||||
# TS = [1, 5, 10, 25, 50, 100, 200, 300]
|
||||
# TS = [1, 5, 10, 25, 50, 100, 200, 300, 400, 500, 1000]
|
||||
# TS: Final = [1]
|
||||
# TS: Final = [1, 5, 10]
|
||||
TS: Final = [1, 5, 10, 25, 50]
|
||||
# TS: Final = [1, 5, 10, 25, 50, 100, 200]
|
||||
# TS: Final = [1, 5, 10, 25, 50, 100, 200, 300]
|
||||
# TS: Final = [1, 5, 10, 25, 50, 100, 200, 300, 400, 500, 1000]
|
||||
|
||||
# Enable verbose output (for debugging purposes)
|
||||
__DEBUG = False
|
||||
__DEBUG: Final = False
|
||||
# Debugging options
|
||||
if __DEBUG:
|
||||
IDX_INSPECT = 4548
|
||||
IDX_INSPECT_OFFSET = 100
|
||||
IDX_INSPECT: Final = 4548
|
||||
IDX_INSPECT_OFFSET: Final = 100
|
||||
np.seterr(all = 'raise')
|
||||
# Debug option (image width * log_10(length) + extra characters)
|
||||
np.set_printoptions(linewidth = 19 * 6 + 3)
|
||||
|
110
python/projet.py
110
python/projet.py
@ -4,6 +4,7 @@
|
||||
from ViolaJones import train_viola_jones, classify_viola_jones
|
||||
from toolbox import state_saver, picke_multi_loader, format_time_ns, benchmark_function, unit_test_argsort_2d
|
||||
from toolbox_unit_test import format_time_ns_test
|
||||
from toolbox import header, footer, formatted_row, formatted_line
|
||||
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix
|
||||
from sklearn.feature_selection import SelectPercentile, f_classif
|
||||
from common import load_datasets, unit_test
|
||||
@ -38,9 +39,11 @@ def preprocessing() -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||
for folder_name in ["models", "out"]:
|
||||
makedirs(folder_name, exist_ok = True)
|
||||
|
||||
print(f"| {'Preprocessing':<49} | {'Time spent (ns)':<18} | {'Formatted time spent':<29} |\n|{'-'*51}|{'-'*20}|{'-'*31}|")
|
||||
preproc_timestamp = perf_counter_ns()
|
||||
preproc_gaps = [49, -18, 29]
|
||||
header(['Preprocessing', 'Time spent (ns)', 'Formatted time spent'], preproc_gaps)
|
||||
|
||||
X_train, y_train, X_test, y_test = state_saver("Loading sets", ["X_train", "y_train", "X_test", "y_test"],
|
||||
X_train, y_train, X_test, y_test = state_saver('Loading sets', preproc_gaps[0], ['X_train', 'y_train', 'X_test', 'y_test'],
|
||||
load_datasets, FORCE_REDO, SAVE_STATE)
|
||||
|
||||
if __DEBUG:
|
||||
@ -57,16 +60,17 @@ def preprocessing() -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||
print(y_test.shape)
|
||||
print(y_test[IDX_INSPECT: IDX_INSPECT + IDX_INSPECT_OFFSET])
|
||||
|
||||
feats = state_saver("Building features", "feats", lambda: build_features(X_train.shape[1], X_train.shape[2]), FORCE_REDO, SAVE_STATE)
|
||||
feats = state_saver('Building features', preproc_gaps[0], 'feats', lambda: build_features(X_train.shape[1], X_train.shape[2]),
|
||||
FORCE_REDO, SAVE_STATE)
|
||||
|
||||
if __DEBUG:
|
||||
print("feats")
|
||||
print(feats.shape)
|
||||
print(feats[IDX_INSPECT].ravel())
|
||||
|
||||
X_train_ii = state_saver(f"Converting training set to integral images ({label})", f"X_train_ii_{label}",
|
||||
X_train_ii = state_saver(f'Converting training set to integral images ({label})', preproc_gaps[0], f'X_train_ii_{label}',
|
||||
lambda: set_integral_image(X_train), FORCE_REDO, SAVE_STATE)
|
||||
X_test_ii = state_saver(f"Converting testing set to integral images ({label})", f"X_test_ii_{label}",
|
||||
X_test_ii = state_saver(f'Converting testing set to integral images ({label})', preproc_gaps[0], f'X_test_ii_{label}',
|
||||
lambda: set_integral_image(X_test), FORCE_REDO, SAVE_STATE)
|
||||
|
||||
if __DEBUG:
|
||||
@ -77,9 +81,9 @@ def preprocessing() -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||
print(X_test_ii.shape)
|
||||
print(X_test_ii[IDX_INSPECT])
|
||||
|
||||
X_train_feat = state_saver(f"Applying features to training set ({label})", f"X_train_feat_{label}",
|
||||
X_train_feat = state_saver(f'Applying features to training set ({label})', preproc_gaps[0], f'X_train_feat_{label}',
|
||||
lambda: apply_features(feats, X_train_ii), FORCE_REDO, SAVE_STATE)
|
||||
X_test_feat = state_saver(f"Applying features to testing set ({label})", f"X_test_feat_{label}",
|
||||
X_test_feat = state_saver(f'Applying features to testing set ({label})', preproc_gaps[0], f'X_test_feat_{label}',
|
||||
lambda: apply_features(feats, X_test_ii), FORCE_REDO, SAVE_STATE)
|
||||
del X_train_ii, X_test_ii, feats
|
||||
|
||||
@ -106,14 +110,14 @@ def preprocessing() -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||
|
||||
# X_train_feat, X_test_feat = X_train_feat[indices], X_test_feat[indices]
|
||||
|
||||
X_train_feat_argsort = state_saver(f"Precalculating training set argsort ({label})", f"X_train_feat_argsort_{label}",
|
||||
X_train_feat_argsort = state_saver(f'Precalculating training set argsort ({label})', preproc_gaps[0], f'X_train_feat_argsort_{label}',
|
||||
lambda: argsort(X_train_feat), FORCE_REDO, SAVE_STATE)
|
||||
|
||||
if __DEBUG:
|
||||
print("X_train_feat_argsort")
|
||||
print(X_train_feat_argsort.shape)
|
||||
print(X_train_feat_argsort[IDX_INSPECT, : IDX_INSPECT_OFFSET])
|
||||
benchmark_function("Arg unit test", lambda: unit_test_argsort_2d(X_train_feat, X_train_feat_argsort))
|
||||
benchmark_function('Arg unit test', preproc_gaps[0], lambda: unit_test_argsort_2d(X_train_feat, X_train_feat_argsort))
|
||||
|
||||
X_test_feat_argsort = state_saver(f"Precalculating testing set argsort ({label})", f"X_test_feat_argsort_{label}",
|
||||
lambda: argsort(X_test_feat), FORCE_REDO, SAVE_STATE)
|
||||
@ -123,48 +127,70 @@ def preprocessing() -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||
print(X_test_feat_argsort.shape)
|
||||
print(X_test_feat_argsort[IDX_INSPECT, : IDX_INSPECT_OFFSET])
|
||||
benchmark_function("Arg unit test", lambda: unit_test_argsort_2d(X_test_feat, X_test_feat_argsort))
|
||||
time_spent = perf_counter_ns() - preproc_timestamp
|
||||
formatted_line(preproc_gaps, '├', '┼', '─', '┤')
|
||||
formatted_row(preproc_gaps, ['Preprocessing summary', f'{time_spent:,}', format_time_ns(time_spent)])
|
||||
footer(preproc_gaps)
|
||||
|
||||
return X_train_feat, X_train_feat_argsort, y_train, X_test_feat, y_test
|
||||
|
||||
def train(X_train_feat: np.ndarray, X_train_feat_argsort: np.ndarray, y_train: np.ndarray) -> None:
|
||||
def train(X_train_feat: np.ndarray, X_train_feat_argsort: np.ndarray, y_train: np.ndarray) -> List[np.ndarray]:
|
||||
"""Train the weak classifiers.
|
||||
|
||||
Args:
|
||||
X_train (np.ndarray): Training images.
|
||||
X_test (np.ndarray): Testing Images.
|
||||
X_train_feat_argsort (np.ndarray): Sorted indexes of the training images features.
|
||||
y_train (np.ndarray): Training labels.
|
||||
|
||||
Returns: List of trained models
|
||||
"""
|
||||
print(f"\n| {'Training':<49} | {'Time spent (ns)':<18} | {'Formatted time spent':<29} |\n|{'-'*51}|{'-'*20}|{'-'*31}|")
|
||||
|
||||
training_timestamp = perf_counter_ns()
|
||||
training_gaps = [26, -18, 29]
|
||||
header(['Training', 'Time spent (ns)', 'Formatted time spent'], training_gaps)
|
||||
models = []
|
||||
|
||||
for T in TS:
|
||||
alphas, final_classifiers = state_saver(f"ViolaJones T = {T:<3} ({label})", [f"alphas_{T}_{label}", f"final_classifiers_{T}_{label}"],
|
||||
lambda: train_viola_jones(T, X_train_feat, X_train_feat_argsort, y_train), FORCE_REDO, SAVE_STATE, MODEL_DIR)
|
||||
alphas, final_classifiers = state_saver(f'ViolaJones T = {T:<4} ({label})', training_gaps[0],
|
||||
[f'alphas_{T}_{label}', f'final_classifiers_{T}_{label}'],
|
||||
lambda: train_viola_jones(T, X_train_feat, X_train_feat_argsort, y_train), FORCE_REDO, SAVE_STATE, MODEL_DIR)
|
||||
models.append([alphas, final_classifiers])
|
||||
|
||||
if __DEBUG:
|
||||
print("alphas")
|
||||
print(alphas)
|
||||
print("final_classifiers")
|
||||
print(final_classifiers)
|
||||
|
||||
def testing_and_evaluating(X_train_feat: np.ndarray, y_train: np.ndarray, X_test_feat: np.ndarray, y_test: np.ndarray) -> None:
|
||||
time_spent = perf_counter_ns() - training_timestamp
|
||||
formatted_line(training_gaps, '├', '┼', '─', '┤')
|
||||
formatted_row(training_gaps, ['Training summary', f'{time_spent:,}', format_time_ns(time_spent)])
|
||||
footer(training_gaps)
|
||||
|
||||
return models
|
||||
|
||||
def testing_and_evaluating(models: List[np.ndarray], X_train_feat: np.ndarray, y_train: np.ndarray, X_test_feat: np.ndarray, y_test: np.ndarray) -> None:
|
||||
"""Benchmark the trained classifiers on the training and testing sets.
|
||||
|
||||
Args:
|
||||
models (List[np.ndarray]): List of trained models.
|
||||
X_train_feat (np.ndarray): Training features.
|
||||
y_train (np.ndarray): Training labels.
|
||||
X_test_feat (np.ndarray): Testing features.
|
||||
y_test (np.ndarray): Testing labels.
|
||||
"""
|
||||
print(f"\n| {'Testing':<26} | Time spent (ns) (E) | {'Formatted time spent (E)':<29}", end = " | ")
|
||||
print(f"Time spent (ns) (T) | {'Formatted time spent (T)':<29} |")
|
||||
print(f"|{'-'*28}|{'-'*21}|{'-'*31}|{'-'*21}|{'-'*31}|")
|
||||
|
||||
perfs = []
|
||||
for T in TS:
|
||||
(alphas, final_classifiers) = picke_multi_loader([f"alphas_{T}_{label}", f"final_classifiers_{T}_{label}"])
|
||||
testing_gaps = [26, -19, 24, -19, 24]
|
||||
header(['Testing', 'Time spent (ns) (E)', 'Formatted time spent (E)', 'Time spent (ns) (T)', 'Formatted time spent (T)'], testing_gaps)
|
||||
|
||||
performances = []
|
||||
total_train_timestamp = 0
|
||||
total_test_timestamp = 0
|
||||
for T, (alphas, final_classifiers) in zip(TS, models):
|
||||
s = perf_counter_ns()
|
||||
y_pred_train = classify_viola_jones(alphas, final_classifiers, X_train_feat)
|
||||
t_pred_train = perf_counter_ns() - s
|
||||
total_train_timestamp += t_pred_train
|
||||
e_acc = accuracy_score(y_train, y_pred_train)
|
||||
e_f1 = f1_score(y_train, y_pred_train)
|
||||
(_, e_FP), (e_FN, _) = confusion_matrix(y_train, y_pred_train)
|
||||
@ -172,36 +198,48 @@ def testing_and_evaluating(X_train_feat: np.ndarray, y_train: np.ndarray, X_test
|
||||
s = perf_counter_ns()
|
||||
y_pred_test = classify_viola_jones(alphas, final_classifiers, X_test_feat)
|
||||
t_pred_test = perf_counter_ns() - s
|
||||
total_test_timestamp += t_pred_test
|
||||
t_acc = accuracy_score(y_test, y_pred_test)
|
||||
t_f1 = f1_score(y_test, y_pred_test)
|
||||
(_, t_FP), (t_FN, _) = confusion_matrix(y_test, y_pred_test)
|
||||
perfs.append((e_acc, e_f1, e_FN, e_FP, t_acc, t_f1, t_FN, t_FP))
|
||||
performances.append((e_acc, e_f1, e_FN, e_FP, t_acc, t_f1, t_FN, t_FP))
|
||||
|
||||
print(f"| {'ViolaJones T = ' + str(T):<19} {'(' + label + ')':<6}", end = " | ")
|
||||
print(f"{t_pred_train:>19,} | {format_time_ns(t_pred_train):<29}", end = " | ")
|
||||
print(f"{t_pred_test:>19,} | {format_time_ns(t_pred_test):<29} |")
|
||||
formatted_row(testing_gaps, [f"{'ViolaJones T = ' + str(T):<19} {'(' + label + ')':<6}", f'{t_pred_train:,}',
|
||||
format_time_ns(t_pred_train), f'{t_pred_test:,}', format_time_ns(t_pred_test)])
|
||||
|
||||
print(f"\n| {'Evaluating':<19} | ACC (E) | F1 (E) | FN (E) | FP (E) | ACC (T) | F1 (T) | FN (T) | FP (T) | ")
|
||||
print(f"|{'-'*21}|{'-'*9}|{'-'*8}|{'-'*8}|{'-'*8}|{'-'*9}|{'-'*8}|{'-'*8}|{'-'*8}|")
|
||||
formatted_line(testing_gaps, '├', '┼', '─', '┤')
|
||||
formatted_row(testing_gaps, ['Testing summary', f'{total_train_timestamp:,}', format_time_ns(total_train_timestamp), f'{total_test_timestamp:,}',
|
||||
format_time_ns(total_test_timestamp)])
|
||||
footer(testing_gaps)
|
||||
|
||||
for T, (e_acc, e_f1, e_FN, e_FP, t_acc, t_f1, t_FN, t_FP) in zip(TS, perfs):
|
||||
print(f"| {'ViolaJones T = ' + str(T):<19} | {e_acc:>7.2%} | {e_f1:>6.2f} | {e_FN:>6,} | {e_FP:>6,}", end = " | ")
|
||||
print(f"{t_acc:>7.2%} | {t_f1:>6.2f} | {t_FN:>6,} | {t_FP:>6,} |")
|
||||
evaluating_gaps = [19, 7, 6, 6, 6, 7, 6, 6, 6]
|
||||
header(['Evaluating', 'ACC (E)', 'F1 (E)', 'FN (E)', 'FP (E)', 'ACC (T)', 'F1 (T)', 'FN (T)', 'FP (T)'], evaluating_gaps)
|
||||
|
||||
for T, (e_acc, e_f1, e_FN, e_FP, t_acc, t_f1, t_FN, t_FP) in zip(TS, performances):
|
||||
print(f'│ ViolaJones T = {T:<4} │ {e_acc:>7.2%} │ {e_f1:>6.2f} │ {e_FN:>6,} │ {e_FP:>6,}', end = ' │ ')
|
||||
print(f'{t_acc:>7.2%} │ {t_f1:>6.2f} │ {t_FN:>6,} │ {t_FP:>6,} │')
|
||||
|
||||
footer(evaluating_gaps)
|
||||
|
||||
def main() -> None:
|
||||
print(f"| {'Unit testing':<49} | {'Time spent (ns)':<18} | {'Formatted time spent':<29} |")
|
||||
print(f"|{'-'*51}|{'-'*20}|{'-'*31}|")
|
||||
benchmark_function("Testing format_time_ns", format_time_ns_test)
|
||||
print()
|
||||
unit_timestamp = perf_counter_ns()
|
||||
unit_gaps = [27, -18, 29]
|
||||
header(['Unit testing', 'Time spent (ns)', 'Formatted time spent'], unit_gaps)
|
||||
benchmark_function('testing format_time', unit_gaps[0], format_time_test)
|
||||
benchmark_function('testing format_time_ns', unit_gaps[0], format_time_ns_test)
|
||||
time_spent = perf_counter_ns() - unit_timestamp
|
||||
formatted_line(unit_gaps, '├', '┼', '─', '┤')
|
||||
formatted_row(unit_gaps, ['Unit testing summary', f'{time_spent:,}', format_time_ns(time_spent)])
|
||||
footer(unit_gaps)
|
||||
|
||||
X_train_feat, X_train_feat_argsort, y_train, X_test_feat, y_test = preprocessing()
|
||||
train(X_train_feat, X_train_feat_argsort, y_train)
|
||||
models = train(X_train_feat, X_train_feat_argsort, y_train)
|
||||
|
||||
# X_train_feat, X_test_feat = picke_multi_loader([f"X_train_feat_{label}", f"X_test_feat_{label}"], OUT_DIR)
|
||||
# indices = picke_multi_loader(["indices"], OUT_DIR)[0]
|
||||
# X_train_feat, X_test_feat = X_train_feat[indices], X_test_feat[indices]
|
||||
|
||||
testing_and_evaluating(X_train_feat, y_train, X_test_feat, y_test)
|
||||
testing_and_evaluating(models, X_train_feat, y_train, X_test_feat, y_test)
|
||||
unit_test(TS)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -2,12 +2,35 @@ from typing import Any, Callable, List, Union, Final
|
||||
from time import perf_counter_ns
|
||||
from numba import njit
|
||||
import numpy as np
|
||||
from sys import stderr
|
||||
import pickle
|
||||
import os
|
||||
from config import MODEL_DIR, OUT_DIR
|
||||
from decorators import njit
|
||||
|
||||
time_formats: Final = ["ns", "µs", "ms", "s", "m", "h", "j", "w", "M", "y", "c"]
|
||||
def formatted_row(gaps: list[int], titles: list[str], separator: str = '│') -> None:
|
||||
for gap, title in zip(gaps, titles):
|
||||
print(f"{separator} {title:{'>' if gap < 0 else '<'}{abs(gap)}} ", end = '')
|
||||
print(separator)
|
||||
|
||||
def formatted_line(gaps: list[int], right: str, middle: str, separator: str, left: str) -> None:
|
||||
print(right, end = '')
|
||||
last_gap = len(gaps) - 1
|
||||
for i, gap in enumerate(gaps):
|
||||
print(f'{separator * (abs(gap) + 2)}', end = '')
|
||||
if i != last_gap:
|
||||
print(middle, end = '')
|
||||
print(left)
|
||||
|
||||
def header(titles: list[str], gaps: list[int]) -> None:
|
||||
formatted_line(gaps, '┌', '┬', '─', '┐')
|
||||
formatted_row(gaps, titles)
|
||||
formatted_line(gaps, '├', '┼', '─', '┤')
|
||||
|
||||
def footer(gaps: list[int]) -> None:
|
||||
formatted_line(gaps, '└', '┴', '─', '┘')
|
||||
|
||||
time_numbers: Final = np.array([1, 1e3, 1e6, 1e9, 6e10, 36e11, 864e11, 6048e11, 26784e11, 31536e12, 31536e14], dtype = np.uint64)
|
||||
@njit('str(uint64)')
|
||||
def format_time_ns(time: int) -> str:
|
||||
@ -53,7 +76,7 @@ def picke_multi_loader(filenames: List[str], save_dir: str = MODEL_DIR) -> List[
|
||||
b.append(None)
|
||||
return b
|
||||
|
||||
def benchmark_function(step_name: str, fnc: Callable) -> Any:
|
||||
def benchmark_function(step_name: str, column_width: int, fnc: Callable) -> Any:
|
||||
"""Benchmark a function and display the result of stdout.
|
||||
|
||||
Args:
|
||||
@ -63,14 +86,15 @@ def benchmark_function(step_name: str, fnc: Callable) -> Any:
|
||||
Returns:
|
||||
Any: Result of the function.
|
||||
"""
|
||||
print(f"{step_name}...", end = "\r")
|
||||
print(f'{step_name}...', file = stderr, end = '\r')
|
||||
s = perf_counter_ns()
|
||||
b = fnc()
|
||||
e = perf_counter_ns() - s
|
||||
print(f"| {step_name:<49} | {e:>18,} | {format_time_ns(e):<29} |")
|
||||
print(f'│ {step_name:<{column_width}} │ {e:>18,} │ {format_time_ns(e):<29} │')
|
||||
return b
|
||||
|
||||
def state_saver(step_name: str, filename: Union[str, List[str]], fnc, force_redo: bool = False, save_state: bool = True, save_dir: str = OUT_DIR) -> Any:
|
||||
def state_saver(step_name: str, column_width: int, filename: Union[str, List[str]], fnc, force_redo: bool = False,
|
||||
save_state: bool = True, save_dir: str = OUT_DIR) -> Any:
|
||||
"""Either execute a function then saves the result or load the already existing result.
|
||||
|
||||
Args:
|
||||
@ -85,18 +109,18 @@ def state_saver(step_name: str, filename: Union[str, List[str]], fnc, force_redo
|
||||
"""
|
||||
if isinstance(filename, str):
|
||||
if not os.path.exists(f"{save_dir}/{filename}.pkl") or force_redo:
|
||||
b = benchmark_function(step_name, fnc)
|
||||
b = benchmark_function(step_name, column_width, fnc)
|
||||
if save_state:
|
||||
print(f"Saving results of {step_name}", end = '\r')
|
||||
with open(f"{save_dir}/{filename}.pkl", 'wb') as f:
|
||||
print(f'Saving results of {step_name}', file = stderr, end = '\r')
|
||||
pickle.dump(b, f)
|
||||
print(' ' * 100, end = '\r')
|
||||
print(' ' * 100, file = stderr, end = '\r')
|
||||
return b
|
||||
else:
|
||||
print(f"Loading results of {step_name}", end = '\r')
|
||||
with open(f"{save_dir}/{filename}.pkl", "rb") as f:
|
||||
print(f'Loading results of {step_name}', file = stderr, end = '\r')
|
||||
res = pickle.load(f)
|
||||
print(f"| {step_name:<49} | {'None':>18} | {'loaded saved state':<29} |")
|
||||
print(f"│ {step_name:<{column_width}} │ {'None':>18} │ {'loaded saved state':<29} │")
|
||||
return res
|
||||
elif isinstance(filename, list):
|
||||
abs = False
|
||||
@ -105,22 +129,22 @@ def state_saver(step_name: str, filename: Union[str, List[str]], fnc, force_redo
|
||||
abs = True
|
||||
break
|
||||
if abs or force_redo:
|
||||
b = benchmark_function(step_name, fnc)
|
||||
b = benchmark_function(step_name, column_width, fnc)
|
||||
if save_state:
|
||||
print(f"Saving results of {step_name}", end = '\r')
|
||||
print(f'Saving results of {step_name}', file = stderr, end = '\r')
|
||||
for bi, fnI in zip(b, filename):
|
||||
with open(f"{save_dir}/{fnI}.pkl", 'wb') as f:
|
||||
pickle.dump(bi, f)
|
||||
print(' ' * 100, end = '\r')
|
||||
print(' ' * 100, file = stderr, end = '\r')
|
||||
return b
|
||||
|
||||
print(f"| {step_name:<49} | {'None':>18} | {'loaded saved state':<29} |")
|
||||
print(f"│ {step_name:<{column_width}} │ {'None':>18} │ {'loaded saved state':<29} │")
|
||||
b = []
|
||||
print(f"Loading results of {step_name}", end = '\r')
|
||||
print(f'Loading results of {step_name}', file = stderr, end = '\r')
|
||||
for fn in filename:
|
||||
with open(f"{save_dir}/{fn}.pkl", "rb") as f:
|
||||
b.append(pickle.load(f))
|
||||
print(' ' * 100, end = '\r')
|
||||
print(' ' * 100, file = stderr, end = '\r')
|
||||
return b
|
||||
else:
|
||||
assert False, f"Incompatible filename type = {type(filename)}"
|
||||
|
Reference in New Issue
Block a user