applications/Chat/coati/experience_buffer/utils.py (+37 −37)
@@ -1,74 +1,74 @@
 from dataclasses import dataclass
-from typing import List, Optional
+from typing import List
 
 import torch
 import torch.nn.functional as F
 from coati.experience_maker.base import Experience
 
 
 @dataclass
 class BufferItem:
-    """BufferItem is an item of experience data.
+    """BufferItem is an item of `Experience` data.
 
     Shapes of each tensor:
-    sequences: (S)
-    action_log_probs: (A)
-    values: (1)
-    reward: (1)
-    advantages: (1)
-    attention_mask: (S)
-    action_mask: (A)
+    sequences: (S)
+    attention_mask: (S)
+    action_mask: (A)
+    step_mask: (N)
+    action_log_probs: (A)
+    values: (N)
+    returns: (N)
+    advantages: (N)
 
     "A" is the number of actions.
     """
 
     sequences: torch.Tensor
+    attention_mask: torch.LongTensor
+    action_mask: torch.BoolTensor
+    step_mask: torch.BoolTensor
     action_log_probs: torch.Tensor
     values: torch.Tensor
-    reward: torch.Tensor
+    returns: torch.Tensor
     advantages: torch.Tensor
-    attention_mask: Optional[torch.LongTensor]
-    action_mask: Optional[torch.BoolTensor]
 
 
 def split_experience_batch(experience: Experience) -> List[BufferItem]:
     batch_size = experience.sequences.size(0)
     batch_kwargs = [{} for _ in range(batch_size)]
-    keys = ("sequences", "action_log_probs", "values", "reward", "advantages", "attention_mask", "action_mask")
+    keys = (
+        "sequences",
+        "attention_mask",
+        "action_mask",
+        "step_mask",
+        "action_log_probs",
+        "values",
+        "returns",
+        "advantages",
+    )
     for key in keys:
         value = getattr(experience, key)
-        if isinstance(value, torch.Tensor):
-            vals = torch.unbind(value)
-        else:
-            # None
-            vals = [value for _ in range(batch_size)]
+        assert isinstance(value, torch.Tensor)
+        vals = torch.unbind(value)
         assert batch_size == len(vals)
         for i, v in enumerate(vals):
             batch_kwargs[i][key] = v
     items = [BufferItem(**kwargs) for kwargs in batch_kwargs]
     return items
 
 
 def _zero_pad_sequences(sequences: List[torch.Tensor], side: str = "left") -> torch.Tensor:
     assert side in ("left", "right")
     max_len = max(seq.size(0) for seq in sequences)
     padded_sequences = []
     for seq in sequences:
         pad_len = max_len - seq.size(0)
         padding = (pad_len, 0) if side == "left" else (0, pad_len)
         padded_sequences.append(F.pad(seq, padding))
     return torch.stack(padded_sequences, dim=0)
 
 
 def make_experience_batch(items: List[BufferItem]) -> Experience:
     kwargs = {}
-    to_pad_keys = set(("action_log_probs", "action_mask"))
-    keys = ("sequences", "action_log_probs", "values", "reward", "advantages", "attention_mask", "action_mask")
+    keys = (
+        "sequences",
+        "attention_mask",
+        "action_mask",
+        "step_mask",
+        "action_log_probs",
+        "values",
+        "returns",
+        "advantages",
+    )
     for key in keys:
         vals = [getattr(item, key) for item in items]
-        if key in to_pad_keys:
-            batch_data = _zero_pad_sequences(vals)
-        else:
-            batch_data = torch.stack(vals, dim=0)
+        batch_data = torch.stack(vals, dim=0)
         kwargs[key] = batch_data
     return Experience(**kwargs)
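
A minimal round-trip sketch of the two helpers above (not part of the diff): it builds a toy `Experience`, splits it into `BufferItem`s, and re-batches it. All shapes and values are made up (B = 2, S = 6, A = 4, N = 2), and the import paths are assumed from the file paths shown in this PR. Since every field is now a fixed-shape tensor, `split_experience_batch` followed by `make_experience_batch` should reproduce the batch exactly, with no padding step involved.

```python
# Hedged sketch, not from the PR: toy shapes B=2, S=6, A=4, N=2.
import torch

from coati.experience_buffer.utils import make_experience_batch, split_experience_batch
from coati.experience_maker.base import Experience

B, S, A, N = 2, 6, 4, 2
exp = Experience(
    sequences=torch.randint(0, 1000, (B, S)),           # (B, S) token ids
    attention_mask=torch.ones(B, S, dtype=torch.long),  # (B, S)
    action_mask=torch.ones(B, A, dtype=torch.bool),     # (B, A)
    step_mask=torch.ones(B, N, dtype=torch.bool),       # (B, N)
    action_log_probs=torch.randn(B, A),                 # (B, A)
    values=torch.randn(B, N),                           # (B, N)
    returns=torch.randn(B, N),                          # (B, N)
    advantages=torch.randn(B, N),                       # (B, N)
)

items = split_experience_batch(exp)      # one BufferItem per batch row
assert len(items) == B
rebuilt = make_experience_batch(items)   # torch.stack along dim 0 for every field
assert torch.equal(rebuilt.sequences, exp.sequences)
assert torch.equal(rebuilt.values, exp.values)
```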
applications/Chat/coati/experience_maker/__init__.py (+6 −2)
@@ -1,4 +1,8 @@
 from .base import Experience, ExperienceMaker
-from .naive import NaiveExperienceMaker
+from .chunked import ChunkedExperienceMaker
 
-__all__ = ["Experience", "ExperienceMaker", "NaiveExperienceMaker"]
+__all__ = [
+    "Experience",
+    "ExperienceMaker",
+    "ChunkedExperienceMaker",
+]
applications/Chat/coati/experience_maker/base.py (+37 −23)
@@ -1,6 +1,5 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import Optional
 
 import torch
 from coati.models.base import Actor, Critic, RewardModel
@@ -9,51 +8,66 @@
 @dataclass
 class Experience:
     """Experience is a batch of data.
-    These data should have the sequence length and number of actions.
-    Left padding for sequences is applied.
 
     "B" is the batch size.
+    "S" is the sequence length.
+    "A" is the number of actions.
+    "C" is the chunk size.
+    "N" is the number of MDP steps.
+    NOTE: N = A / C; each Experience contains N MDP steps ([s0, a0], [s1, a1], ...),
+        sequences = |pad|prompt|a0|a1|a2|...|pad|,
+        s0 = prompt, s1 = prompt + a0, s2 = prompt + a0 + a1, ...
+    FIXME(cwher): storing N steps in an Experience can be computationally efficient,
+        but may differ from uniform sampling (shuffle all steps and sample).
 
     Shapes of each tensor:
-    sequences: (B, S)
-    action_log_probs: (B, A)
-    values: (B)
-    reward: (B)
-    advantages: (B)
-    attention_mask: (B, S)
-    action_mask: (B, A)
+    sequences: (B, S)
+    attention_mask: (B, S)
+    action_mask: (B, A)
+    step_mask: (B, N)
+    action_log_probs: (B, A)
+    values: (B, N), output of the old critic model
+    returns: (B, N), result of GAE
+    advantages: (B, N), result of GAE
 
-    "A" is the number of actions.
+    e.g.,
+        sequences = |pad|prompt|response|pad|
+        attention_mask = |0|1|1|0|
+        action_mask = |1|0| (for response)
+
+    NOTE: an `Experience` is split into `BufferItem`s when added to the buffer.
     """
 
     sequences: torch.Tensor
+    attention_mask: torch.LongTensor
+    action_mask: torch.BoolTensor
+    step_mask: torch.BoolTensor
     action_log_probs: torch.Tensor
     values: torch.Tensor
-    reward: torch.Tensor
+    returns: torch.Tensor
    advantages: torch.Tensor
-    attention_mask: Optional[torch.LongTensor]
-    action_mask: Optional[torch.BoolTensor]
 
     @torch.no_grad()
     def to_device(self, device: torch.device) -> None:
         self.sequences = self.sequences.to(device)
+        self.attention_mask = self.attention_mask.to(device)
+        self.action_mask = self.action_mask.to(device)
+        self.step_mask = self.step_mask.to(device)
         self.action_log_probs = self.action_log_probs.to(device)
         self.values = self.values.to(device)
-        self.reward = self.reward.to(device)
+        self.returns = self.returns.to(device)
         self.advantages = self.advantages.to(device)
-        if self.attention_mask is not None:
-            self.attention_mask = self.attention_mask.to(device)
-        if self.action_mask is not None:
-            self.action_mask = self.action_mask.to(device)
 
     def pin_memory(self):
         self.sequences = self.sequences.pin_memory()
+        self.attention_mask = self.attention_mask.pin_memory()
+        self.action_mask = self.action_mask.pin_memory()
+        self.step_mask = self.step_mask.pin_memory()
         self.action_log_probs = self.action_log_probs.pin_memory()
         self.values = self.values.pin_memory()
-        self.reward = self.reward.pin_memory()
+        self.returns = self.returns.pin_memory()
         self.advantages = self.advantages.pin_memory()
-        if self.attention_mask is not None:
-            self.attention_mask = self.attention_mask.pin_memory()
-        if self.action_mask is not None:
-            self.action_mask = self.action_mask.pin_memory()
         return self
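
To make the `N = A / C` note in the docstring concrete, a small hand-worked sketch (toy numbers only, not from the PR): a response of A = 6 action tokens consumed in chunks of C = 2 gives N = 3 MDP steps, and `step_mask` plays the same role for steps that `action_mask` plays for tokens. The `per_step_loss` below is a hypothetical placeholder for any per-step quantity.

```python
# Hedged sketch of the chunked-MDP bookkeeping described in the docstring above.
import torch

A, C = 6, 2   # 6 action tokens, chunks of 2
N = A // C    # 3 MDP steps:
#   step 0: s0 = prompt,            a0 = tokens[0:2]
#   step 1: s1 = prompt + a0,       a1 = tokens[2:4]
#   step 2: s2 = prompt + a0 + a1,  a2 = tokens[4:6]

# A batch of 2 samples where the second one stopped after 2 steps:
step_mask = torch.tensor([[True, True, True],
                          [True, True, False]])  # (B, N)

# values / returns / advantages share this (B, N) layout, so any per-step
# quantity can be reduced while ignoring masked-out steps, e.g.:
per_step_loss = torch.randn(2, N)                # placeholder numbers
mean_loss = (per_step_loss * step_mask).sum() / step_mask.sum()
```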

