
feat: swiglu forward optimizations#63

Open
aghilann wants to merge 8 commits into NVIDIA:main from aghilann:swiglu-optimizations

Conversation


@aghilann aghilann commented Feb 23, 2026

Description

  • Implements a minimal, forward-only SwiGLU optimization in `src/tilegym/ops/cutile/swiglu.py`.
  • Uses fast sigmoid math (`flush_to_zero=True`) plus an approximate reciprocal via `rounding_mode=RMd.APPROX` to reduce scalar math cost.
  • Uses gather/scatter instead of load/store.
  • Preserves backward behavior while improving forward throughput.
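For context, the forward computation being optimized is SwiGLU, i.e. `silu(a) * b` with the sigmoid evaluated in float32 (as in the kernel snippets reviewed below). A minimal NumPy reference of the math, not the cuTile kernel itself:

```python
import numpy as np

def swiglu_forward_ref(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Reference SwiGLU forward: silu(a) * b, sigmoid computed in float32."""
    a32 = a.astype(np.float32)
    sigmoid_a = 1.0 / (1.0 + np.exp(-a32))   # sigmoid(a)
    silu_a = a32 * sigmoid_a                 # silu(a) = a * sigmoid(a)
    return silu_a.astype(a.dtype) * b

a = np.array([-1.0, 0.0, 2.0], dtype=np.float16)
b = np.array([3.0, 3.0, 3.0], dtype=np.float16)
print(swiglu_forward_ref(a, b))
```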

Benchmark Results (Added bfloat16 + float32 in addition to float16)

| Suite | main CuTile (GB/s) | swiglu-optimizations CuTile (GB/s) | Speedup |
|---|---:|---:|---:|
| swiglu-batch1-M128-bfloat16-GBps | 1083.29 | 1723.81 | 1.591x |
| swiglu-batch1-M128-float16-GBps | 1206.48 | 1741.01 | 1.443x |
| swiglu-batch1-M128-float32-GBps | 1767.60 | 2330.74 | 1.319x |
| swiglu-batch1-M4096-bfloat16-GBps | 1685.89 | 1877.48 | 1.114x |
| swiglu-batch1-M4096-float16-GBps | 1593.83 | 1742.88 | 1.094x |
| swiglu-batch1-M4096-float32-GBps | 1236.80 | 1252.18 | 1.012x |
| swiglu-batch4-M128-bfloat16-GBps | 1919.65 | 2634.60 | 1.372x |
| swiglu-batch4-M128-float16-GBps | 1987.82 | 2639.57 | 1.328x |
| swiglu-batch4-M128-float32-GBps | 2008.17 | 2471.87 | 1.231x |
| swiglu-batch4-M4096-bfloat16-GBps | 787.96 | 790.20 | 1.003x |
| swiglu-batch4-M4096-float16-GBps | 787.30 | 791.15 | 1.005x |
| swiglu-batch4-M4096-float32-GBps | 775.18 | 774.31 | 0.999x |
| swiglu-batch8-M128-bfloat16-GBps | 2063.11 | 2593.24 | 1.257x |
| swiglu-batch8-M128-float16-GBps | 2088.77 | 2598.10 | 1.244x |
| swiglu-batch8-M128-float32-GBps | 1994.69 | 2197.19 | 1.102x |
| swiglu-batch8-M4096-bfloat16-GBps | 774.35 | 776.40 | 1.003x |
| swiglu-batch8-M4096-float16-GBps | 773.72 | 773.67 | 1.000x |
| swiglu-batch8-M4096-float32-GBps | 773.00 | 773.09 | 1.000x |
| Overall (mean of suites) | 1405.98 | 1693.42 | 1.204x |
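As a sanity check on the aggregation (an assumption on my part: the overall row looks like the arithmetic mean of each throughput column, with the speedup taken as the ratio of those means):

```python
main = [1083.29, 1206.48, 1767.60, 1685.89, 1593.83, 1236.80,
        1919.65, 1987.82, 2008.17, 787.96, 787.30, 775.18,
        2063.11, 2088.77, 1994.69, 774.35, 773.72, 773.00]
opt = [1723.81, 1741.01, 2330.74, 1877.48, 1742.88, 1252.18,
       2634.60, 2639.57, 2471.87, 790.20, 791.15, 774.31,
       2593.24, 2598.10, 2197.19, 776.40, 773.67, 773.09]

mean_main = sum(main) / len(main)  # arithmetic mean over the 18 suites
mean_opt = sum(opt) / len(opt)
print(round(mean_main, 2), round(mean_opt, 2), round(mean_opt / mean_main, 3))
# → 1405.98 1693.42 1.204
```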

Notes for PR:

CI Configuration

```yaml
config:
  build: true
  # valid options are "ops" and "benchmark"
  test: ["ops", "benchmark"]
```

Checklist

  • Code formatted and imports sorted via repo specifications (./format.sh)
  • Documentation updated (if needed)
  • CI configuration reviewed


copy-pr-bot bot commented Feb 23, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.


```diff
  # Compute sigmoid(a) and silu(a)
- sigmoid_a = sigmoid(a_tile_f32)
+ sigmoid_a = 1.0 / (1.0 + ct.exp(-a_tile_f32))
```

@aghilann aghilann Feb 23, 2026


Inlined this for now because I didn’t want to modify the backward kernel in this PR - that would require re-benchmarking it as well. I have additional optimizations planned that I’ll include in a separate PR, which will also make use of the new sigmoid implementation I added.

@aghilann force-pushed the swiglu-optimizations branch from 381e8dc to 0461595 on February 23, 2026 at 06:27
```diff
 def sigmoid(x):
-    return 1.0 / (1.0 + ct.exp(-x))
+    denom = ct.add(1.0, ct.exp(-x), flush_to_zero=True)
+    return ct.truediv(1.0, denom, flush_to_zero=True, rounding_mode=RMd.APPROX)
```
@aghilann aghilann Feb 23, 2026


A good chunk of the savings came from `RMd.APPROX`, without losing precision (verified via tests).
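One way to see why an approximate reciprocal barely affects sigmoid: sigmoid(x) = 1/(1+exp(-x)) ≤ 1, so a relative error ε in the reciprocal moves the output by at most ε in absolute terms. A small sketch, where the 1e-6 bound is an assumption for illustration, not the documented bound of the hardware instruction:

```python
import numpy as np

REL_ERR = 1e-6  # assumed relative error of the approximate reciprocal

x = np.linspace(-20.0, 20.0, 100_001)
denom = 1.0 + np.exp(-x)
exact = 1.0 / denom  # exact sigmoid

# Worst-case perturbations of the reciprocal by +/- REL_ERR
worst = np.maximum(np.abs(exact * (1.0 + REL_ERR) - exact),
                   np.abs(exact * (1.0 - REL_ERR) - exact))

# Since sigmoid <= 1, the absolute output error is bounded by REL_ERR
print(float(worst.max()) <= REL_ERR * 1.0000001)  # → True
```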

```python
a_tile = ct.gather(a, (row, offsets), check_bounds=True, padding_value=0.0)
# Sigmoid requires type float32
c_tile = silu(a_tile.astype(ct.float32)).astype(a.dtype) * b_tile
ct.store(c, index=(row, col), tile=c_tile)
```

A good chunk of the perf improvements came from using gather/scatter instead of load/store.
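For readers unfamiliar with the gather semantics used here: `ct.gather(..., check_bounds=True, padding_value=0.0)` returns the padding value for out-of-range offsets instead of faulting. A NumPy emulation of that behavior on one row (my own sketch, not cuTile's implementation):

```python
import numpy as np

def gather_1d(src: np.ndarray, offsets: np.ndarray, padding_value=0.0) -> np.ndarray:
    """Bounds-checked gather: out-of-range offsets yield padding_value."""
    in_bounds = (offsets >= 0) & (offsets < src.shape[0])
    safe = np.where(in_bounds, offsets, 0)          # clamp to a valid index
    out = src[safe]                                 # gather
    return np.where(in_bounds, out, padding_value)  # apply padding

row = np.array([10.0, 20.0, 30.0])
print(gather_1d(row, np.array([0, 2, 5])))  # offset 5 is out of range -> 0.0
```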

```python
create_benchmark_config(batch_size, M, dtype)
for batch_size in [1, 4, 8]  # Different batch sizes
for M in [128, 4096]  # Different rows
for dtype in [torch.float16, torch.bfloat16, torch.float32]
```

Most benchmarks test across various dtypes, so I thought this one should too.
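The comprehension above expands to 3 × 2 × 3 = 18 suites, matching the benchmark table. A standalone sketch, with a hypothetical stub for `create_benchmark_config` and dtype names standing in for the torch dtypes:

```python
# Hypothetical stand-in for the repo's create_benchmark_config helper:
# it just records the parameters of one benchmark case.
def create_benchmark_config(batch_size, M, dtype):
    return {"batch_size": batch_size, "M": M, "dtype": dtype}

configs = [
    create_benchmark_config(batch_size, M, dtype)
    for batch_size in [1, 4, 8]                      # different batch sizes
    for M in [128, 4096]                             # different row counts
    for dtype in ["float16", "bfloat16", "float32"]  # torch dtype stand-ins
]
print(len(configs))  # → 18
```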

@aghilann

Hey @hannahli-nv, another day - another cuTILE perf upgrade!

@aghilann

@xjmxyt Any chance I could get a review :)
