Throughout the API we often take parameters like `kT`, `dt`, and `alpha`. In some cases these are typed as `float | torch.Tensor`, in others they are just `torch.Tensor`. We should:
- Standardize on `float | torch.Tensor` in init and step functions.
- Use `x = torch.as_tensor(x)` instead of the `if` statement often used.