ode_TrainingDeepONetIVP
Functions which train a DeepONet, implemented in JAX/Flax, to solve the problems passed to solveODE_DeepONet_IVP calls.
Equations of order 1, 2, and 3, each with hard and soft constraints, have their own training files. Almost all functions are the same across these files, with slight changes to the model and to the functions that interact with the model for each specific problem. The commonly used functions are outlined here. All files can be found on GitHub in the ODE/SpecificTraining/DeepONet/DeepONetIVP directory.
Although these functions are available to the user, they are not meant to be called directly. They are meant to be used through the object returned from solveODE calls, where the training file is selected by ode_trainingSelect.
startTraining(eval_points, inits, order, t, N_sensors, sensor_range, epochs, net_layers, net_units, eqn)
Main training function, called by PINNtrainSelect_DeepONet. It generates the DeepONet, the collocation points (de_points), the sensors, and the network parameters, then calls train_network on the DeepONet, and finally computes the network's solution prediction.
Normalize
Bases: Module
Class which describes a normalization layer for the DeepONet. Returns the input data normalized to the interval [-1, 1].
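A minimal sketch of the affine map such a layer could apply, written as a plain JAX function rather than a Flax Module. The name `normalize` and the assumption that the bounds `lo` and `hi` come from the problem domain (for example `t_bdry`) are illustrative, not taken from the source files.

```python
import jax.numpy as jnp

def normalize(x, lo, hi):
    """Affinely map values in [lo, hi] onto [-1, 1] (hypothetical helper)."""
    return 2.0 * (x - lo) / (hi - lo) - 1.0

t = jnp.linspace(0.0, 2.0, 5)    # 0.0, 0.5, 1.0, 1.5, 2.0
t_norm = normalize(t, 0.0, 2.0)  # -1.0, -0.5, 0.0, 0.5, 1.0
```

Mapping inputs to [-1, 1] keeps them in the region where activations such as tanh are most sensitive, which usually helps training.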
CombineBranches
Bases: Module
Class which combines the outputs of two branch nets and returns the combined result.
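The documentation does not say how the two branch outputs are combined. The sketch below assumes an elementwise (Hadamard) product, a common choice for multi-input operator networks, followed by the usual DeepONet dot product with the trunk features. Both function names are hypothetical.

```python
import jax.numpy as jnp

def combine_branches(b1, b2):
    """Merge two branch-net feature arrays of equal width p.

    Assumption: elementwise product; the actual class may combine differently.
    """
    return b1 * b2

def deeponet_output(b_combined, trunk):
    """Contract over the p latent features.

    b_combined: (batch, p), trunk: (n_points, p) -> (batch, n_points)
    """
    return jnp.einsum("bp,np->bn", b_combined, trunk)
```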
HardConstraint
Bases: Module
Class which applies the hard constraint of the initial values to the network. Returns the input data with the hard constraint applied.
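One standard way to hard-constrain an IVP of order n is the ansatz below: a truncated Taylor polynomial built from the initial values plus a network term multiplied by (t - t0)^n, which vanishes together with its first n - 1 derivatives at t0. This is a sketch of that general technique, not necessarily the exact form used in the source files; `hard_constraint`, `t0`, and the stand-in network output are hypothetical.

```python
import math
import jax
import jax.numpy as jnp

def hard_constraint(t, net_out, inits, t0=0.0):
    """Enforce y^(k)(t0) = inits[k] for k < n, where n = len(inits) is the ODE order.

    y(t) = sum_k inits[k] (t - t0)^k / k!  +  (t - t0)^n * net_out
    The second term and its first n - 1 derivatives vanish at t0, so the
    initial values hold exactly regardless of the network output.
    """
    dt = t - t0
    taylor = 0.0
    power = 1.0  # holds (t - t0)^k, built incrementally to avoid pow at dt = 0
    for k, y_k in enumerate(inits):
        taylor = taylor + y_k * power / math.factorial(k)
        power = power * dt
    return taylor + power * net_out  # power == (t - t0)^n here

# Second-order example: y(0) = 1, y'(0) = -0.5, with a stand-in "network" output
y = lambda t: hard_constraint(t, jnp.sin(3.0 * t) + 2.0, [1.0, -0.5])
```

Because the constraint is built into the output, the loss only has to penalize the ODE residual, which is the main difference from the soft-constrained files.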
MLP
Bases: Module
Class which describes the MLP used as the basis for the branch and trunk nets. Its depth and width are set by the user inputs net_layers and net_units.
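A plain-JAX sketch of an MLP parameterized the same way, assuming `net_layers` is the number of hidden layers and `net_units` the width of each one. The tanh activation, the Glorot-style initialization, and the helper names `init_mlp` and `mlp` are assumptions, not details from the source.

```python
import jax
import jax.numpy as jnp

def init_mlp(key, in_dim, net_layers, net_units, out_dim):
    """Create `net_layers` hidden layers of width `net_units` plus a linear output layer."""
    sizes = [in_dim] + [net_units] * net_layers + [out_dim]
    keys = jax.random.split(key, len(sizes) - 1)
    params = []
    for k, m, n in zip(keys, sizes[:-1], sizes[1:]):
        # Glorot-style scaling keeps tanh pre-activations in a reasonable range
        w = jax.random.normal(k, (m, n)) * jnp.sqrt(2.0 / (m + n))
        params.append((w, jnp.zeros(n)))
    return params

def mlp(params, x):
    """Apply tanh hidden layers followed by a linear output layer."""
    for w, b in params[:-1]:
        x = jnp.tanh(x @ w + b)
    w, b = params[-1]
    return x @ w + b

params = init_mlp(jax.random.PRNGKey(0), in_dim=1, net_layers=3, net_units=16, out_dim=8)
features = mlp(params, jnp.ones((5, 1)))  # shape (5, 8)
```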
DeepONet
Bases: Module
Class which describes the DeepONet model. Creates the MLP-based trunk and branch networks required by the order of the problem, normalizes the input, and applies the hard constraint if the equation is hard-constrained.
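The core DeepONet forward pass can be sketched as follows: a branch net encodes the sensor values, a trunk net encodes the evaluation points, and the output is the dot product of the two feature vectors plus a bias. For brevity this sketch uses a single branch net and fixed layer sizes; the documented class builds as many branches as the problem order needs. All names and sizes here are hypothetical.

```python
import jax
import jax.numpy as jnp

def init_mlp(key, sizes):
    keys = jax.random.split(key, len(sizes) - 1)
    return [(jax.random.normal(k, (m, n)) * jnp.sqrt(2.0 / (m + n)), jnp.zeros(n))
            for k, m, n in zip(keys, sizes[:-1], sizes[1:])]

def mlp(params, x):
    for w, b in params[:-1]:
        x = jnp.tanh(x @ w + b)
    w, b = params[-1]
    return x @ w + b

def init_deeponet(key, n_sensors, width=32, p=16):
    kb, kt = jax.random.split(key)
    return {"branch": init_mlp(kb, [n_sensors, width, width, p]),
            "trunk": init_mlp(kt, [1, width, width, p]),
            "bias": jnp.zeros(())}

def deeponet(params, u_sensors, t):
    """G(u)(t) ~ sum_k b_k(u) * tau_k(t) + bias.

    u_sensors: (batch, n_sensors), t: (n_points, 1) -> (batch, n_points)
    """
    b = mlp(params["branch"], u_sensors)  # (batch, p) branch features
    tau = mlp(params["trunk"], t)         # (n_points, p) trunk features
    return b @ tau.T + params["bias"]

params = init_deeponet(jax.random.PRNGKey(0), n_sensors=10)
pred = deeponet(params, jnp.ones((4, 10)), jnp.linspace(0.0, 1.0, 7)[:, None])
```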
train_network(params, des, zsensor, epochs)
Main training loop, which calls train_step. It packages the data, runs the training routine, and optimizes the network.
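The shape of such a loop can be sketched in plain JAX: a jitted step computes the loss and its gradients, and the loop repeats it for a number of epochs. The optimizer used in the actual files is not documented here, so plain gradient descent stands in for it; `make_train_step`, the `train_network` signature, `lr`, and `log_every` are hypothetical.

```python
import jax
import jax.numpy as jnp

def make_train_step(loss_fn, lr):
    """Build a jitted step that takes one gradient-descent step on loss_fn."""
    @jax.jit
    def step(params, batch):
        loss, grads = jax.value_and_grad(loss_fn)(params, batch)
        new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
        return new_params, loss
    return step

def train_network(params, loss_fn, batch, epochs, lr=1e-1, log_every=100):
    """Hypothetical loop: call the step `epochs` times, recording the loss periodically."""
    step = make_train_step(loss_fn, lr)
    history = []
    for epoch in range(epochs):
        params, loss = step(params, batch)
        if epoch % log_every == 0:
            history.append(float(loss))
    return params, history
```

Jitting the step once and reusing it inside a Python loop keeps compilation out of the per-epoch cost.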
train_step(params, pdes, z)
Main training function. Defines the derivatives of the network and the loss function, which minimizes the residual of the input equation and, when the initial values are soft-constrained, the initial-value error.
Parameters: params, pdes, and z; versions for higher-order problems may also take arguments zt, ztt, etc. for the derivative sensors.
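A minimal sketch of a physics-informed training step with a soft initial-value constraint, using the toy ODE y' = -y, y(0) = 1. A tiny scalar network stands in for the DeepONet output, `jax.grad` supplies the time derivative at each collocation point, and one gradient-descent update is taken. The toy equation, the network `u_net`, `init_params`, and `LR` are all hypothetical.

```python
import jax
import jax.numpy as jnp

LR = 1e-2  # hypothetical learning rate

def u_net(params, t):
    """Tiny scalar network u(t); stands in for the DeepONet output at one point."""
    h = jnp.tanh(params["w1"] * t + params["b1"])  # (hidden,)
    return jnp.dot(params["w2"], h) + params["b2"]

def init_params(key, hidden=16):
    k1, k2 = jax.random.split(key)
    return {"w1": jax.random.normal(k1, (hidden,)), "b1": jnp.zeros(hidden),
            "w2": 0.1 * jax.random.normal(k2, (hidden,)), "b2": jnp.zeros(())}

def loss_fn(params, t_colloc, y0=1.0, t0=0.0):
    u = jax.vmap(lambda t: u_net(params, t))(t_colloc)
    # du/dt at every collocation point via automatic differentiation
    du = jax.vmap(jax.grad(u_net, argnums=1), in_axes=(None, 0))(params, t_colloc)
    residual = du + u                # ODE y' = -y  ->  residual y' + y
    ic = u_net(params, t0) - y0      # soft constraint on the initial value
    return jnp.mean(residual ** 2) + ic ** 2

@jax.jit
def train_step(params, t_colloc):
    loss, grads = jax.value_and_grad(loss_fn)(params, t_colloc)
    params = jax.tree_util.tree_map(lambda p, g: p - LR * g, params, grads)
    return params, loss
```

For a second-order problem the same pattern applies with a nested `jax.grad` for the second derivative and an extra initial-value term for y'(t0).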
defineCollocationPoints(t_bdry, N, sensor_range)
File-specific collocation function which generates the sampled ode_points and sensors. Different orders require slightly different versions of this function.
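A sketch of what such sampling might look like, assuming the ODE points are drawn uniformly from the time interval `t_bdry` and the sensors are initial values (z, zt, ztt, ...) drawn uniformly from `sensor_range`, one row per derivative order. That interpretation of the sensors, the `order` and `n_sensors` arguments, and the function name are assumptions.

```python
import jax
import jax.numpy as jnp

def define_collocation_points(key, t_bdry, N, sensor_range, order, n_sensors):
    """Hypothetical sampler for collocation points and initial-value sensors."""
    k_t, k_s = jax.random.split(key)
    t0, t1 = t_bdry
    # N time points inside the domain, shaped (N, 1) for a trunk net
    ode_points = jax.random.uniform(k_t, (N, 1), minval=t0, maxval=t1)
    lo, hi = sensor_range
    # One row of sampled initial values per derivative order: z, zt, ztt, ...
    sensors = jax.random.uniform(k_s, (order, n_sensors), minval=lo, maxval=hi)
    return ode_points, sensors

ode_points, sensors = define_collocation_points(
    jax.random.PRNGKey(0), t_bdry=(0.0, 2.0), N=100,
    sensor_range=(-1.0, 1.0), order=2, n_sensors=50)
```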