Skip to content

Commit 97a1828

Browse files
Fix mypy type errors in transaction_manager.py and transactions.py
1 parent 570055e commit 97a1828

2 files changed

Lines changed: 24 additions & 13 deletions

File tree

codegen-on-oss/codegen_on_oss/analyzers/transaction_manager.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,9 @@ def __init__(self) -> None:
6161
self.pending_undos: Set[Callable[[], None]] = set()
6262
self._commiting: bool = False
6363
self.max_transactions: Optional[int] = None # None = no limit
64-
self.stopwatch_start = None
64+
self.stopwatch_start: Optional[float] = None
6565
self.stopwatch_max_seconds: Optional[int] = None # None = no limit
66-
self.session = {} # Session data for tracking state
66+
self.session: Dict[str, Any] = {} # Session data for tracking state
6767

6868
def sort_transactions(self) -> None:
6969
"""Sort transactions by priority and position."""
@@ -127,7 +127,7 @@ def reset_stopwatch(self, max_seconds: Optional[int] = None) -> None:
127127

128128
def is_time_exceeded(self) -> bool:
129129
"""Check if the stopwatch time limit has been exceeded."""
130-
if self.stopwatch_max_seconds is None:
130+
if self.stopwatch_max_seconds is None or self.stopwatch_start is None:
131131
return False
132132
else:
133133
num_seconds = time.time() - self.stopwatch_start
@@ -384,7 +384,7 @@ def get_transactions_at_range(self, file_path: Path, start_byte: int, end_byte:
384384
Returns:
385385
List of matching transactions
386386
"""
387-
matching_transactions = []
387+
matching_transactions: List[Transaction] = []
388388
if file_path not in self.queued_transactions:
389389
return matching_transactions
390390

@@ -435,14 +435,24 @@ def _get_conflicts(self, transaction: Transaction) -> List[Transaction]:
435435
Returns:
436436
List of conflicting transactions
437437
"""
438-
overlapping_transactions = []
438+
overlapping_transactions: List[Transaction] = []
439439
if transaction.file_path not in self.queued_transactions:
440440
return overlapping_transactions
441-
442-
queued_transactions = list(self.queued_transactions[transaction.file_path])
443-
for t in queued_transactions:
444-
if transaction.start_byte < t.end_byte and transaction.end_byte > t.start_byte:
441+
442+
for t in self.queued_transactions[transaction.file_path]:
443+
# Skip if it's the same transaction
444+
if t == transaction:
445+
continue
446+
447+
# Check if the transactions overlap
448+
if (
449+
(t.start_byte <= transaction.start_byte < t.end_byte)
450+
or (t.start_byte < transaction.end_byte <= t.end_byte)
451+
or (transaction.start_byte <= t.start_byte < transaction.end_byte)
452+
or (transaction.start_byte < t.end_byte <= transaction.end_byte)
453+
):
445454
overlapping_transactions.append(t)
455+
446456
return overlapping_transactions
447457

448458
def _get_overlapping_conflicts(self, transaction: Transaction) -> Optional[Transaction]:
@@ -456,9 +466,8 @@ def _get_overlapping_conflicts(self, transaction: Transaction) -> Optional[Trans
456466
"""
457467
if transaction.file_path not in self.queued_transactions:
458468
return None
459-
469+
460470
for t in self.queued_transactions[transaction.file_path]:
461471
if transaction.start_byte >= t.start_byte and transaction.end_byte <= t.end_byte:
462472
return t
463473
return None
464-

codegen-on-oss/codegen_on_oss/analyzers/transactions.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,8 @@ def __init__(
210210

211211
def _generate_new_content_bytes(self) -> bytes:
212212
"""Generate the new content bytes after insertion."""
213+
if self.new_content is None:
214+
raise ValueError("Cannot generate content bytes: new_content is None")
213215
new_bytes = bytes(self.new_content, encoding="utf-8")
214216
content_bytes = self.file.content_bytes
215217
head = content_bytes[: self.insert_byte]
@@ -230,7 +232,8 @@ def get_diff(self) -> DiffLite:
230232
def diff_str(self) -> str:
231233
"""Human-readable string representation of the change."""
232234
diff = "".join(unified_diff(self.file.content.splitlines(True), self._generate_new_content_bytes().decode("utf-8").splitlines(True)))
233-
return f"Insert {len(self.new_content)} bytes at bytes ({self.start_byte}, {self.end_byte})\n{diff}"
235+
content_length = len(self.new_content) if self.new_content is not None else 0
236+
return f"Insert {content_length} bytes at bytes ({self.start_byte}, {self.end_byte})\n{diff}"
234237

235238
class EditTransaction(Transaction):
236239
"""Transaction to edit content in a file."""
@@ -364,4 +367,3 @@ def get_diff(self) -> DiffLite:
364367
def diff_str(self) -> str:
365368
"""Human-readable string representation of the change."""
366369
return f"Remove file at {self.file_path}"
367-

0 commit comments

Comments
 (0)