codes package
Subpackages
- codes.benchmark package
- Submodules
- codes.benchmark.bench_fcts module
compare_UQ(), compare_batchsize(), compare_errors(), compare_extrapolation(), compare_gradients(), compare_inference_time(), compare_interpolation(), compare_iterative(), compare_main_losses(), compare_models(), compare_sparse(), evaluate_UQ(), evaluate_accuracy(), evaluate_batchsize(), evaluate_compute(), evaluate_extrapolation(), evaluate_gradients(), evaluate_interpolation(), evaluate_iterative_predictions(), evaluate_sparse(), run_benchmark(), tabular_comparison(), time_inference()
- codes.benchmark.bench_plots module
get_custom_palette(), inference_time_bar_plot(), plot_MAE_comparison(), plot_all_generalization_errors(), plot_average_errors_over_time(), plot_average_uncertainty_over_time(), plot_catastrophic_detection_curves(), plot_comparative_error_correlation_heatmaps(), plot_comparative_gradient_heatmaps(), plot_dynamic_correlation(), plot_error_distribution_comparative(), plot_error_distribution_per_quantity(), plot_error_percentiles_over_time(), plot_errors_over_time(), plot_example_iterative_predictions(), plot_example_mode_predictions(), plot_example_predictions_with_uncertainty(), plot_generalization_error_comparison(), plot_generalization_errors(), plot_gradients_heatmap(), plot_loss_comparison(), plot_loss_comparison_equal(), plot_loss_comparison_train_duration(), plot_losses(), plot_losses_dual_axis(), plot_mean_deltadex_over_time_main_vs_ensemble(), plot_relative_errors(), plot_surr_losses(), plot_uncertainty_confidence(), plot_uncertainty_heatmap(), plot_uncertainty_over_time_comparison(), plot_uncertainty_vs_errors(), rel_errors_and_uq(), save_plot(), save_plot_counter()
- codes.benchmark.bench_utils module
check_benchmark(), check_surrogate(), clean_metrics(), convert_dict_to_scientific_notation(), convert_to_standard_types(), count_trainable_parameters(), discard_numpy_entries(), flatten_dict(), format_seconds(), format_time(), format_value(), get_model_config(), get_required_models_list(), get_surrogate(), load_model(), make_comparison_csv(), measure_inference_time(), measure_memory_footprint(), save_table_csv(), write_metrics_to_yaml()
- Module contents
check_benchmark(), check_surrogate(), clean_metrics(), compare_UQ(), compare_batchsize(), compare_errors(), compare_extrapolation(), compare_gradients(), compare_inference_time(), compare_interpolation(), compare_main_losses(), compare_models(), compare_sparse(), convert_dict_to_scientific_notation(), convert_to_standard_types(), count_trainable_parameters(), discard_numpy_entries(), evaluate_UQ(), evaluate_accuracy(), evaluate_batchsize(), evaluate_compute(), evaluate_extrapolation(), evaluate_gradients(), evaluate_interpolation(), evaluate_sparse(), flatten_dict(), format_seconds(), format_time(), format_value(), get_custom_palette(), get_model_config(), get_required_models_list(), get_surrogate(), inference_time_bar_plot(), load_model(), make_comparison_csv(), measure_inference_time(), measure_memory_footprint(), plot_MAE_comparison(), plot_all_generalization_errors(), plot_average_errors_over_time(), plot_average_uncertainty_over_time(), plot_comparative_error_correlation_heatmaps(), plot_comparative_gradient_heatmaps(), plot_dynamic_correlation(), plot_error_distribution_comparative(), plot_error_distribution_per_quantity(), plot_error_percentiles_over_time(), plot_example_iterative_predictions(), plot_example_mode_predictions(), plot_example_predictions_with_uncertainty(), plot_generalization_error_comparison(), plot_generalization_errors(), plot_gradients_heatmap(), plot_loss_comparison(), plot_loss_comparison_train_duration(), plot_losses(), plot_relative_errors(), plot_surr_losses(), plot_uncertainty_confidence(), plot_uncertainty_heatmap(), plot_uncertainty_over_time_comparison(), plot_uncertainty_vs_errors(), read_yaml_config(), rel_errors_and_uq(), run_benchmark(), save_plot(), save_plot_counter(), save_table_csv(), tabular_comparison(), time_inference(), write_metrics_to_yaml()
- codes.surrogates package
- Submodules
- codes.surrogates.surrogate_classes module
- codes.surrogates.surrogates module
- Module contents
AbstractSurrogateModel, AbstractSurrogateModel.train_loss, AbstractSurrogateModel.test_loss, AbstractSurrogateModel.MAE, AbstractSurrogateModel.normalisation, AbstractSurrogateModel.train_duration, AbstractSurrogateModel.device, AbstractSurrogateModel.n_quantities, AbstractSurrogateModel.n_timesteps, AbstractSurrogateModel.L1, AbstractSurrogateModel.config, AbstractSurrogateModel.checkpoint(), AbstractSurrogateModel.denormalize(), AbstractSurrogateModel.denormalize_old(), AbstractSurrogateModel.fit(), AbstractSurrogateModel.forward(), AbstractSurrogateModel.get_checkpoint(), AbstractSurrogateModel.get_registered_classes(), AbstractSurrogateModel.load(), AbstractSurrogateModel.predict(), AbstractSurrogateModel.prepare_data(), AbstractSurrogateModel.register(), AbstractSurrogateModel.save(), AbstractSurrogateModel.setup_checkpoint(), AbstractSurrogateModel.setup_optimizer_and_scheduler(), AbstractSurrogateModel.setup_progress_bar(), AbstractSurrogateModel.time_pruning(), AbstractSurrogateModel.validate()
BranchNet, ChemDataset, Decoder, Encoder, FlatSeqBatchIterable, FullyConnected, FullyConnectedNet, LatentNeuralODE, LatentPoly, ModelWrapper, MultiONet, ODE, Polynomial, TrunkNet
- codes.train package
- codes.tune package
- Submodules
- codes.tune.evaluate_study module
- codes.tune.evaluate_tuning module
- codes.tune.optuna_fcts module
- codes.tune.postgres_fcts module
- codes.tune.tune_utils module
- Module contents
MaxValidTrialsCallback, build_fine_optuna_params(), build_study_names(), copy_config(), create_objective(), delete_studies_if_requested(), initialize_optuna_database(), load_model_test_losses(), load_study_config(), load_yaml_config(), make_optuna_params(), maybe_set_runtime_threshold(), moving_average(), plot_test_losses(), prepare_workspace(), training_run(), yes_no()
- codes.utils package
- Submodules
- codes.utils.data_utils module
- codes.utils.utils module
batch_factor_to_float(), check_training_status(), create_model_dir(), determine_batch_size(), get_progress_bar(), load_and_save_config(), load_task_list(), make_description(), nice_print(), parse_for_none(), parse_hyperparameters(), read_yaml_config(), save_task_list(), set_random_seeds(), time_execution(), worker_init_fn()
- Module contents
batch_factor_to_float(), check_and_load_data(), check_training_status(), create_dataset(), create_hdf5_dataset(), create_model_dir(), determine_batch_size(), download_data(), get_data_subset(), get_progress_bar(), load_and_save_config(), load_task_list(), make_description(), nice_print(), normalize_data(), parse_hyperparameters(), read_yaml_config(), save_task_list(), set_random_seeds(), time_execution(), worker_init_fn()