Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 43 additions & 8 deletions src/diffusers/schedulers/scheduling_ipndm.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,18 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
Args:
num_train_timesteps (`int`, defaults to 1000):
The number of diffusion steps to train the model.
trained_betas (`np.ndarray`, *optional*):
trained_betas (`np.ndarray` or `List[float]`, *optional*):
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
"""

order = 1

@register_to_config
def __init__(self, num_train_timesteps: int = 1000, trained_betas: np.ndarray | list[float] | None = None):
def __init__(
self,
num_train_timesteps: int = 1000,
trained_betas: np.ndarray | list[float] | None = None,
):
# set `betas`, `alphas`, `timesteps`
self.set_timesteps(num_train_timesteps)

Expand All @@ -56,21 +60,29 @@ def __init__(self, num_train_timesteps: int = 1000, trained_betas: np.ndarray |
self._begin_index = None

@property
def step_index(self):
def step_index(self) -> int | None:
"""
The index counter for current timestep. It will increase 1 after each scheduler step.

Returns:
`int` or `None`:
The index counter for current timestep.
"""
return self._step_index

@property
def begin_index(self):
def begin_index(self) -> int | None:
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.

Returns:
`int` or `None`:
The index for the first timestep.
"""
return self._begin_index

# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def set_begin_index(self, begin_index: int = 0):
def set_begin_index(self, begin_index: int = 0) -> None:
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.

Expand Down Expand Up @@ -169,7 +181,7 @@ def step(
Args:
model_output (`torch.Tensor`):
The direct output from learned diffusion model.
timestep (`int`):
timestep (`int` or `torch.Tensor`):
The current discrete timestep in the diffusion chain.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
Expand Down Expand Up @@ -228,7 +240,30 @@ def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tens
"""
return sample

def _get_prev_sample(self, sample, timestep_index, prev_timestep_index, ets):
def _get_prev_sample(
self,
sample: torch.Tensor,
timestep_index: int,
prev_timestep_index: int,
ets: torch.Tensor,
) -> torch.Tensor:
"""
Predicts the previous sample based on the current sample, timestep indices, and running model outputs.

Args:
sample (`torch.Tensor`):
The current sample.
timestep_index (`int`):
Index of the current timestep in the schedule.
prev_timestep_index (`int`):
Index of the previous timestep in the schedule.
ets (`torch.Tensor`):
The running sequence of model outputs.

Returns:
`torch.Tensor`:
The predicted previous sample.
"""
alpha = self.alphas[timestep_index]
sigma = self.betas[timestep_index]

Expand All @@ -240,5 +275,5 @@ def _get_prev_sample(self, sample, timestep_index, prev_timestep_index, ets):

return prev_sample

def __len__(self):
def __len__(self) -> int:
return self.config.num_train_timesteps
Loading