codes package
Subpackages
- codes.benchmark package
- Submodules
- codes.benchmark.bench_fcts module
compare_MAE()
compare_UQ()
compare_batchsize()
compare_dynamic_accuracy()
compare_extrapolation()
compare_inference_time()
compare_interpolation()
compare_main_losses()
compare_models()
compare_relative_errors()
compare_sparse()
evaluate_UQ()
evaluate_accuracy()
evaluate_batchsize()
evaluate_compute()
evaluate_dynamic_accuracy()
evaluate_extrapolation()
evaluate_interpolation()
evaluate_sparse()
run_benchmark()
tabular_comparison()
time_inference()
- codes.benchmark.bench_plots module
get_custom_palette()
inference_time_bar_plot()
int_ext_sparse()
plot_MAE_comparison()
plot_MAE_comparison_train_duration()
plot_average_errors_over_time()
plot_average_uncertainty_over_time()
plot_comparative_dynamic_correlation_heatmaps()
plot_comparative_error_correlation_heatmaps()
plot_dynamic_correlation()
plot_dynamic_correlation_heatmap()
plot_error_correlation_heatmap()
plot_error_distribution_comparative()
plot_error_distribution_per_chemical()
plot_example_predictions_with_uncertainty()
plot_generalization_error_comparison()
plot_generalization_errors()
plot_loss_comparison()
plot_losses()
plot_relative_errors()
plot_relative_errors_over_time()
plot_surr_losses()
plot_uncertainty_over_time_comparison()
plot_uncertainty_vs_errors()
rel_errors_and_uq()
save_plot()
save_plot_counter()
- codes.benchmark.bench_utils module
check_benchmark()
check_surrogate()
clean_metrics()
convert_dict_to_scientific_notation()
convert_to_standard_types()
count_trainable_parameters()
discard_numpy_entries()
flatten_dict()
format_seconds()
format_time()
get_model_config()
get_required_models_list()
get_surrogate()
load_model()
make_comparison_csv()
measure_memory_footprint()
read_yaml_config()
write_metrics_to_yaml()
- Module contents
check_benchmark()
check_surrogate()
clean_metrics()
compare_MAE()
compare_UQ()
compare_batchsize()
compare_dynamic_accuracy()
compare_extrapolation()
compare_inference_time()
compare_interpolation()
compare_main_losses()
compare_models()
compare_relative_errors()
compare_sparse()
convert_dict_to_scientific_notation()
convert_to_standard_types()
count_trainable_parameters()
discard_numpy_entries()
evaluate_UQ()
evaluate_accuracy()
evaluate_batchsize()
evaluate_compute()
evaluate_dynamic_accuracy()
evaluate_extrapolation()
evaluate_interpolation()
evaluate_sparse()
flatten_dict()
format_seconds()
format_time()
get_custom_palette()
get_model_config()
get_required_models_list()
get_surrogate()
inference_time_bar_plot()
int_ext_sparse()
load_model()
make_comparison_csv()
measure_memory_footprint()
plot_MAE_comparison()
plot_MAE_comparison_train_duration()
plot_average_errors_over_time()
plot_average_uncertainty_over_time()
plot_comparative_dynamic_correlation_heatmaps()
plot_comparative_error_correlation_heatmaps()
plot_dynamic_correlation()
plot_dynamic_correlation_heatmap()
plot_error_correlation_heatmap()
plot_error_distribution_comparative()
plot_error_distribution_per_chemical()
plot_example_predictions_with_uncertainty()
plot_generalization_error_comparison()
plot_generalization_errors()
plot_loss_comparison()
plot_losses()
plot_relative_errors()
plot_relative_errors_over_time()
plot_surr_losses()
plot_uncertainty_over_time_comparison()
plot_uncertainty_vs_errors()
read_yaml_config()
rel_errors_and_uq()
run_benchmark()
save_plot()
save_plot_counter()
tabular_comparison()
time_inference()
write_metrics_to_yaml()
- codes.surrogates package
- Submodules
- codes.surrogates.surrogate_classes module
- codes.surrogates.surrogates module
AbstractSurrogateModel
AbstractSurrogateModel.train_loss
AbstractSurrogateModel.test_loss
AbstractSurrogateModel.MAE
AbstractSurrogateModel.normalisation
AbstractSurrogateModel.train_duration
AbstractSurrogateModel.device
AbstractSurrogateModel.n_chemicals
AbstractSurrogateModel.n_timesteps
AbstractSurrogateModel.L1
AbstractSurrogateModel.config
AbstractSurrogateModel.denormalize()
AbstractSurrogateModel.fit()
AbstractSurrogateModel.forward()
AbstractSurrogateModel.load()
AbstractSurrogateModel.predict()
AbstractSurrogateModel.prepare_data()
AbstractSurrogateModel.save()
AbstractSurrogateModel.setup_progress_bar()
- Module contents
AbstractSurrogateModel
AbstractSurrogateModel.train_loss
AbstractSurrogateModel.test_loss
AbstractSurrogateModel.MAE
AbstractSurrogateModel.normalisation
AbstractSurrogateModel.train_duration
AbstractSurrogateModel.device
AbstractSurrogateModel.n_chemicals
AbstractSurrogateModel.n_timesteps
AbstractSurrogateModel.L1
AbstractSurrogateModel.config
AbstractSurrogateModel.denormalize()
AbstractSurrogateModel.fit()
AbstractSurrogateModel.forward()
AbstractSurrogateModel.load()
AbstractSurrogateModel.predict()
AbstractSurrogateModel.prepare_data()
AbstractSurrogateModel.save()
AbstractSurrogateModel.setup_progress_bar()
BranchNet
ChemDataset
Decoder
Encoder
FullyConnected
FullyConnectedNet
LatentNeuralODE
LatentPoly
ModelWrapper
ModelWrapper.config
ModelWrapper.loss_weights
ModelWrapper.encoder
ModelWrapper.decoder
ModelWrapper.ode
ModelWrapper.forward()
ModelWrapper.renormalize_loss_weights()
ModelWrapper.total_loss()
ModelWrapper.identity_loss()
ModelWrapper.l2_loss()
ModelWrapper.deriv_loss()
ModelWrapper.deriv2_loss()
ModelWrapper.deriv()
ModelWrapper.deriv2()
ModelWrapper.deriv()
ModelWrapper.deriv2()
ModelWrapper.deriv2_loss()
ModelWrapper.deriv_loss()
ModelWrapper.forward()
ModelWrapper.identity_loss()
ModelWrapper.l2_loss()
ModelWrapper.renormalize_loss_weights()
ModelWrapper.total_loss()
MultiONet
ODE
Polynomial
TrunkNet
- codes.train package
- codes.utils package
- Submodules
- codes.utils.data_utils module
- codes.utils.utils module
- Module contents
check_and_load_data()
check_training_status()
create_dataset()
create_hdf5_dataset()
create_model_dir()
download_data()
get_data_subset()
get_progress_bar()
load_and_save_config()
load_task_list()
make_description()
nice_print()
normalize_data()
read_yaml_config()
save_task_list()
set_random_seeds()
time_execution()
worker_init_fn()