LSDOlab/python_csdl_backend

Python CSDL Backend

Repository for the new CSDL backend.

Changelog

  • Added optional automatic CPU parallelization using MPI.

    • Performs static list scheduling and generates blocking MPI code for both model evaluation and derivative computation. Whether this improves performance depends on the model structure.

    • To run without parallelization (same as before):

      import csdl
      import python_csdl_backend
      sim = python_csdl_backend.Simulator(Sample())

      Terminal:

      python sample.py
    • To run with parallelization (new):

      import csdl
      import python_csdl_backend
      from mpi4py import MPI # conda install mpi4py
      comm = MPI.COMM_WORLD
      sim = python_csdl_backend.Simulator(Sample(), comm = comm)

      Terminal:

      mpirun -n {number of processors} python sample.py
  • Added optional automatic checkpointing for adjoint computation.

    • Using checkpointing can reduce peak memory usage but can be ~2x slower than without checkpointing.
    • To compute derivatives using checkpointing, pass additional keyword arguments to the Simulator constructor:
      import csdl
      import python_csdl_backend
      sim = python_csdl_backend.Simulator(
          Sample(),
          checkpoints=True,  # bool; required to enable checkpointing
          save_vars={'sample_model.variable_of_interest'},  # set of variable names; optional
          checkpoint_stride=150_000_000,  # int, in bytes; optional
      )
    • save_vars: Checkpointing may delete variables of interest during computation, which can prevent postprocessing (printing or plotting values afterwards). List the names of the variables whose memory should stay allocated in save_vars. Has no effect when checkpoints = False.
    • checkpoint_stride: Rough estimate of the size of each checkpoint interval, in bytes. If omitted, a stride is chosen automatically.
    • When checkpoints = True, additional memory-saving measures are applied, such as evaluating linear partial derivatives lazily instead of precomputing them.
  • Other

    • Visualize checkpoints and parallelization:
      python_csdl_backend.Simulator(Sample(), visualize_schedule = True)
    • Significantly reduced memory usage for derivative computation (even without checkpointing) by deallocating partial derivatives and propagated adjoints as they are processed.
      • May result in reduced performance for small models compared to previous version.
    • Evaluate (some) partial derivatives lazily instead of precomputing them, which can save memory:
      python_csdl_backend.Simulator(Sample(), lazy = True)
    • Added new tests with more complicated models.
      • Run parallel tests using (in root directory):
        mpirun -n {number of processors} pytest
      • The tests now also verify each model in all four configurations:
        • [ ] checkpointing, [ ] parallelization (previous implementation)
        • [x] checkpointing, [ ] parallelization
        • [ ] checkpointing, [x] parallelization
        • [x] checkpointing, [x] parallelization
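The changelog only states that parallelization uses static list scheduling. As a rough illustration of that general technique (not the backend's actual scheduler), the sketch below greedily assigns each operation of a dependency graph to the processor on which it can start earliest. The function name list_schedule, the unit operation costs, and the omission of MPI communication cost are all assumptions made for illustration. The graph mirrors the operations of the Sample model shown later in this README.

```python
def list_schedule(tasks, deps, cost, n_procs):
    """Greedy static list scheduling of a task graph onto n_procs processors.

    tasks: task names in priority order (here simply a topological order)
    deps:  dict mapping a task to the tasks whose outputs it needs
    cost:  dict mapping a task to its estimated run time
    Returns a dict mapping each task to (processor, start_time, finish_time).
    Communication cost between processors is ignored (an assumption).
    """
    proc_free = [0.0] * n_procs  # time at which each processor becomes idle
    schedule = {}
    for task in tasks:
        # The task can start only after all of its inputs are finished.
        ready = max((schedule[d][2] for d in deps.get(task, [])), default=0.0)
        # Choose the processor on which the task can start earliest.
        proc = min(range(n_procs), key=lambda p: max(proc_free[p], ready))
        start = max(proc_free[proc], ready)
        schedule[task] = (proc, start, start + cost[task])
        proc_free[proc] = start + cost[task]
    return schedule


# Operation graph of the Sample model: f = sin(x1) + exp(-x2) + 3
TASKS = ['sin', 'neg', 'exp', 'add', 'add3']
DEPS = {'exp': ['neg'], 'add': ['sin', 'exp'], 'add3': ['add']}
COST = {task: 1.0 for task in TASKS}  # assumed unit cost per operation

schedule = list_schedule(TASKS, DEPS, COST, n_procs=2)
for task, (proc, start, finish) in schedule.items():
    print(f'{task:5s} on processor {proc}: t = {start} -> {finish}')
```

With two processors, sin and the negation run concurrently, so the five unit-cost operations finish at time 4 instead of 5. A real scheduler also has to pay for the blocking MPI transfers between processors, which is one reason parallelization may not help every model.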
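To illustrate the memory/time trade-off described for checkpointing, here is a minimal sketch of checkpointed reverse-mode differentiation on a chain of scalar operations. It is not the backend's implementation: the helper names forward_with_checkpoints and reverse_with_checkpoints are hypothetical, and stride here counts operations rather than bytes as checkpoint_stride does. Only the input and every stride-th state are kept during the forward pass; the backward pass recomputes each segment from its checkpoint and discards each state as soon as its adjoint contribution has been propagated.

```python
import math

# A hypothetical chain of scalar operations: each entry is (f, df/dx).
# The chain computes y = 3 * exp(sin(x)).
STEPS = [
    (math.sin, math.cos),
    (math.exp, math.exp),
    (lambda x: 3.0 * x, lambda x: 3.0),
]


def forward_with_checkpoints(x0, steps, stride):
    """Run the chain, keeping only the input and every `stride`-th state."""
    checkpoints = {0: x0}
    x = x0
    for i, (f, _) in enumerate(steps):
        x = f(x)
        if (i + 1) % stride == 0 and i + 1 < len(steps):
            checkpoints[i + 1] = x
    return x, checkpoints


def reverse_with_checkpoints(steps, checkpoints, stride):
    """Return d(output)/d(input), recomputing each segment from its checkpoint."""
    n = len(steps)
    adjoint = 1.0
    # Process segments from the last one back to the first.
    for seg_start in sorted(checkpoints, reverse=True):
        seg_end = min(seg_start + stride, n)
        # Recompute the inputs of steps seg_start .. seg_end - 1.
        states = [checkpoints.pop(seg_start)]
        for i in range(seg_start, seg_end - 1):
            states.append(steps[i][0](states[-1]))
        # Propagate the adjoint backward, discarding each state once it is used.
        for i in range(seg_end - 1, seg_start - 1, -1):
            adjoint *= steps[i][1](states.pop())
    return adjoint


x0 = 0.7
exact = 3.0 * math.exp(math.sin(x0)) * math.cos(x0)
for stride in (1, 2, 3):
    y, ckpts = forward_with_checkpoints(x0, STEPS, stride)
    dydx = reverse_with_checkpoints(STEPS, ckpts, stride)
    print(f'stride={stride}: y={y:.6f}, dy/dx={dydx:.6f}, exact={exact:.6f}')
```

Most intermediate states are computed twice (once in the forward pass and once during recomputation), which is in line with the roughly 2x slowdown mentioned above, while only about n/stride checkpoints plus one segment of at most stride states are held at any time instead of all n states.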

Installation

Install with pip install -e . in the root directory.

Run the unit tests with pytest in the root directory.

Requires csdl: https://github.com/LSDOlab/csdl

Features

List of features implemented: https://docs.google.com/spreadsheets/d/1WWdCow1ZNf7Tx5pk0ql7nIZ5cehOLWBe8dN6LLhGJ0A/edit?usp=sharing

Note: there is currently no way to visualize the model implementation.

Example:

import csdl
import python_csdl_backend


class Sample(csdl.Model):

    def define(self):
        x1 = self.create_input('x1', val=1.0)
        x2 = self.create_input('x2', val=5.0)

        y1 = csdl.sin(x1)
        y2 = -x2
        y3 = csdl.exp(y2)

        y4 = y1+y3

        f = y4+3

        self.register_output('f', f)


# Simulator should be identical for both backends
sim = python_csdl_backend.Simulator(Sample())

# run
sim.run()

# Print outputs
print('\noutputs:')
print('x1 = ', sim['x1'])
print('x2 = ', sim['x2'])
print('f = ', sim['f'])

# compute totals:
totals = sim.compute_totals(of=['f'], wrt=['x2', 'x1'])
print('\nderivatives:')
for key in totals:
    print(key, '=', totals[key])
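Because the Sample model computes f = sin(x1) + exp(-x2) + 3, its derivatives are df/dx1 = cos(x1) and df/dx2 = -exp(-x2). The standalone sketch below (plain Python, no csdl required) checks those formulas with central finite differences at the default inputs x1 = 1.0 and x2 = 5.0; if the backend is working correctly, the values printed by compute_totals above should agree with them. The helper names f and central_difference are illustrative only.

```python
import math


def f(x1, x2):
    # Same expression as the Sample model: sin(x1) + exp(-x2) + 3
    return math.sin(x1) + math.exp(-x2) + 3.0


def central_difference(func, x, h=1e-6):
    # Second-order accurate finite-difference approximation of func'(x)
    return (func(x + h) - func(x - h)) / (2.0 * h)


x1, x2 = 1.0, 5.0
df_dx1_fd = central_difference(lambda v: f(v, x2), x1)
df_dx2_fd = central_difference(lambda v: f(x1, v), x2)

df_dx1_exact = math.cos(x1)
df_dx2_exact = -math.exp(-x2)

print('f =', f(x1, x2))
print('df/dx1:', df_dx1_fd, 'vs exact', df_dx1_exact)
print('df/dx2:', df_dx2_fd, 'vs exact', df_dx2_exact)
```

At these inputs f ≈ 3.8482, df/dx1 = cos(1) ≈ 0.5403 and df/dx2 = -exp(-5) ≈ -0.006738.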
