diff --git a/src/sc2reader_plugins/apm_tracker.py b/src/sc2reader_plugins/apm_tracker.py index 4e34fd1..a8c3eea 100644 --- a/src/sc2reader_plugins/apm_tracker.py +++ b/src/sc2reader_plugins/apm_tracker.py @@ -32,14 +32,17 @@ def handleInitGame(self, event: "Event", replay: "Replay"): for player in replay.players: player.official_apm = None else: + archon_mode = any(p.archon_leader_id is not None for p in replay.players) + for player in replay.players: # players have pid starting from 1, and there may be observers that # have pid larger than the number of players, so we need to check # if the pid is valid try: - # pid starts from 1, but index starts from 0 - player.official_apm = gamemetadata["Players"][player.pid - 1]["APM"] - except KeyError: + # pid and team id starts from 1, but index starts from 0 + p_index = (player.pid if not archon_mode else player.team_id) - 1 + player.official_apm = gamemetadata["Players"][p_index]["APM"] + except (KeyError, IndexError): player.official_apm = None # build self-calculated apm and aps for player in replay.players: diff --git a/tests/replays/2v2_archon_mode.SC2Replay b/tests/replays/2v2_archon_mode.SC2Replay new file mode 100644 index 0000000..49db3f1 Binary files /dev/null and b/tests/replays/2v2_archon_mode.SC2Replay differ diff --git a/tests/test_apm_tracker.py b/tests/test_apm_tracker.py index 29d5622..e6b345f 100644 --- a/tests/test_apm_tracker.py +++ b/tests/test_apm_tracker.py @@ -92,3 +92,30 @@ def test_2v2(): assert p2.avg_apm > 10 assert p3.avg_apm > 10 assert p4.avg_apm > 10 + + +def test_2v2_archon_mode(): + factory = sc2reader.factories.SC2Factory() + engine = sc2reader.engine.GameEngine( + plugins=[EventSecondCorrector(), ContextLoader(), APMTracker()] + ) + replay = factory.load_replay( + "tests/replays/2v2_archon_mode.SC2Replay", engine=engine + ) + t1, t2 = replay.teams + t1p1, t1p2 = t1.players + t2p1, t2p2 = t2.players + assert t1p1.official_apm > 10 + assert t1p2.official_apm > 10 + assert t2p1.official_apm > 10 + assert t2p2.official_apm > 10 + assert t1p1.official_apm == t1p2.official_apm + assert t2p1.official_apm == t2p2.official_apm + assert len(t1p1.apm) > 0 + assert len(t1p2.apm) > 0 + assert len(t2p1.apm) > 0 + assert len(t2p2.apm) > 0 + assert t1p1.avg_apm > 10 + assert t1p2.avg_apm > 10 + assert t2p1.avg_apm > 10 + assert t2p2.avg_apm > 10