Refactor press implementation #21

@maxjeblick

Feature

Split the press class into two classes:

  • A scorer class that implements the .score method
  • A pruning class that implements the .forward_hook method

The press is then assembled through dependency injection. For example, ExpectedAttentionPress can be expressed as:

press = BasePruner(
    compression_ratio=compression_ratio,
    scorer=ExpectedAttentionScorer(
        n_future_positions=n_future_positions,
        n_sink=n_sink,
        use_covariance=use_covariance,
        use_vnorm=use_vnorm,
    ),
)
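To make the proposed split concrete, here is a minimal, framework-free sketch of how a scorer and a pruner might fit together. Everything in it is an assumption for illustration: `Scorer`, `KeyNormScorer`, and this toy `BasePruner` are hypothetical, keys are plain Python lists rather than tensors, and `forward_hook` takes keys and values directly instead of following the PyTorch hook signature.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass


class Scorer(ABC):
    """Hypothetical interface: one importance score per cached position."""

    @abstractmethod
    def score(self, keys):
        ...


class KeyNormScorer(Scorer):
    """Toy scorer: a position's score is the L2 norm of its key vector."""

    def score(self, keys):
        return [sum(x * x for x in key) ** 0.5 for key in keys]


@dataclass
class BasePruner:
    """Toy pruner: keeps the highest-scoring positions; the scorer is injected."""

    compression_ratio: float
    scorer: Scorer

    def forward_hook(self, keys, values):
        # Simplified signature (assumption): a real hook would receive the
        # module, its inputs, and its outputs.
        scores = self.scorer.score(keys)
        n_kept = int(len(keys) * (1 - self.compression_ratio))
        # Pick the n_kept best positions, then restore their original order.
        top = sorted(range(len(keys)), key=lambda i: scores[i], reverse=True)[:n_kept]
        kept = sorted(top)
        return [keys[i] for i in kept], [values[i] for i in kept]


pruner = BasePruner(compression_ratio=0.5, scorer=KeyNormScorer())
keys = [[3.0, 4.0], [0.1, 0.0], [1.0, 0.0], [0.0, 2.0]]
values = ["a", "b", "c", "d"]
pruned_keys, pruned_values = pruner.forward_hook(keys, values)
```

Swapping in a different scoring rule only requires passing another `Scorer` instance; the pruning logic stays untouched.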

Motivation

The current press code couples the forward hook and the score method, which makes custom workflows harder to implement.
Decoupling pruning from scoring makes it possible to add new pruning methods (e.g. PerLayerCompressionPruner) by subclassing BasePruner rather than by using a wrapper function.
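As a rough illustration of the subclassing idea, the sketch below shows how a per-layer pruner could override a single method of a base pruner. It is a sketch under stated assumptions, not the proposed implementation: the `ratio_for_layer` and `prune` methods, the `AbsScorer` class, and the constructor arguments are all hypothetical, and keys are plain floats rather than tensors.

```python
class AbsScorer:
    """Toy scorer (hypothetical): score is the absolute value of each key."""

    def score(self, keys):
        return [abs(k) for k in keys]


class BasePruner:
    """Minimal stand-in for the proposed base class (hypothetical)."""

    def __init__(self, compression_ratio, scorer):
        self.compression_ratio = compression_ratio
        self.scorer = scorer

    def ratio_for_layer(self, layer_idx):
        # Default behaviour: the same ratio for every layer.
        return self.compression_ratio

    def prune(self, keys, layer_idx):
        scores = self.scorer.score(keys)
        n_kept = int(len(keys) * (1 - self.ratio_for_layer(layer_idx)))
        top = sorted(range(len(keys)), key=lambda i: scores[i], reverse=True)[:n_kept]
        return [keys[i] for i in sorted(top)]


class PerLayerCompressionPruner(BasePruner):
    """Overrides only the ratio lookup; scoring and pruning are inherited."""

    def __init__(self, compression_ratios, scorer):
        super().__init__(compression_ratio=None, scorer=scorer)
        self.compression_ratios = compression_ratios

    def ratio_for_layer(self, layer_idx):
        return self.compression_ratios[layer_idx]


pruner = PerLayerCompressionPruner(compression_ratios=[0.0, 0.5], scorer=AbsScorer())
keys = [4.0, -1.0, 0.5, -3.0]
layer0 = pruner.prune(keys, layer_idx=0)  # ratio 0.0: nothing is dropped
layer1 = pruner.prune(keys, layer_idx=1)  # ratio 0.5: half the keys are dropped
```

Because the scorer is injected, the subclass never touches scoring code, so no wrapper function is needed in this toy setup.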
