12 changes: 12 additions & 0 deletions csrc/amp_C_frontend.cpp
@@ -178,6 +178,16 @@ void multi_tensor_lamb_mp_cuda(
  at::Tensor found_inf,
  at::Tensor inv_scale);

at::Tensor update_scale_hysteresis_cuda(
  at::Tensor current_scale,
  at::Tensor growth_tracker,
  at::Tensor hysteresis_tracker,
  at::Tensor found_inf,
  const double growth_factor,
  const double backoff_factor,
  const int64_t growth_interval,
  const int hysteresis);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("multi_tensor_scale", &multi_tensor_scale_cuda,
        "Fused overflow check + scale for a list of contiguous tensors");
@@ -211,4 +221,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"Computes and apply update for LAMB optimizer");
m.def("multi_tensor_lamb_mp", &multi_tensor_lamb_mp_cuda,
"Computes and apply update for LAMB optimizer");
m.def("update_scale_hysteresis", &update_scale_hysteresis_cuda,
"Updates scale while accounting for hysteresis");
}
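For orientation, the new binding can be driven from Python roughly as below once apex is built with its CUDA extensions. This is a minimal sketch, not part of the diff; the dtypes mirror the pointer types in the kernel that follows (fp32 for the scale and found_inf, int32 for both trackers).

# Minimal usage sketch (assumes an apex build with the CUDA extensions);
# not part of the PR.
import torch
import amp_C

scale = torch.tensor([65536.0], dtype=torch.float32, device='cuda')
growth_tracker = torch.tensor([0], dtype=torch.int32, device='cuda')
hysteresis_tracker = torch.tensor([2], dtype=torch.int32, device='cuda')
found_inf = torch.tensor([1.0], dtype=torch.float32, device='cuda')  # overflow seen

# growth_factor=2.0, backoff_factor=0.5, growth_interval=2000, hysteresis=2
amp_C.update_scale_hysteresis(scale, growth_tracker, hysteresis_tracker,
                              found_inf, 2.0, 0.5, 2000, 2)
print(scale.item())  # still 65536.0: the first inf only consumes hysteresis budget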
71 changes: 71 additions & 0 deletions csrc/update_scale_hysteresis.cu
@@ -0,0 +1,71 @@
#include <ATen/ATen.h>
#include <ATen/cuda/Exceptions.h>
#include <ATen/cuda/CUDAContext.h>

__global__ void update_scale_hysteresis_cuda_kernel(float* current_scale,
                                                    int* growth_tracker,
                                                    int* hysteresis_tracker,
                                                    const float* found_inf,
                                                    double growth_factor,
                                                    double backoff_factor,
                                                    int growth_interval,
                                                    int hysteresis)
{
  if (*found_inf > 0) {
    *hysteresis_tracker -= 1;

    // While the hysteresis budget remains (> 0), only reset the growth
    // tracker and leave the scale untouched.
    if (*hysteresis_tracker > 0) {
      *growth_tracker = 0;
      return;
    }
  }

  if (*found_inf) {
    *current_scale = (*current_scale)*backoff_factor;
    *growth_tracker = 0;
  } else {
    // Entering this branch means we just carried out a successful step,
    // so growth_tracker is incremented before comparing to growth_interval.
    auto successful = (*growth_tracker) + 1;
    if (successful == growth_interval) {
      auto new_scale = static_cast<float>((*current_scale)*growth_factor);
      // Do not grow the scale past fp32 bounds to inf.
      if (isfinite(new_scale)) {
        *current_scale = new_scale;
      }
      *growth_tracker = 0;
    } else {
      *growth_tracker = successful;
    }
  }

  // Reset the hysteresis tracker if no infs are found
  if (*found_inf <= 0) {
    *hysteresis_tracker = hysteresis;
  }
}

at::Tensor update_scale_hysteresis_cuda(at::Tensor current_scale,
                                        at::Tensor growth_tracker,
                                        at::Tensor hysteresis_tracker,
                                        at::Tensor found_inf,
                                        const double growth_factor,
                                        const double backoff_factor,
                                        const int64_t growth_interval,
                                        const int hysteresis)
{
  update_scale_hysteresis_cuda_kernel<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>(
    current_scale.mutable_data_ptr<float>(),
    growth_tracker.mutable_data_ptr<int>(),
    hysteresis_tracker.mutable_data_ptr<int>(),
    found_inf.const_data_ptr<float>(),
    growth_factor,
    backoff_factor,
    growth_interval,
    hysteresis);

  AT_CUDA_CHECK(cudaGetLastError());

  return current_scale;
}
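The single-thread <<<1, 1>>> launch is deliberate: every input is a one-element tensor, so the update is a scalar state transition rather than a data-parallel operation. To make that control flow easier to audit, here is a pure-Python paraphrase of the kernel's state machine (an illustrative sketch, not code from this PR): overflow steps consume the hysteresis budget before any backoff happens, and every clean step refills it.

# Pure-Python paraphrase of update_scale_hysteresis_cuda_kernel (a sketch,
# not part of the PR); returns the updated (scale, growth_tracker,
# hysteresis_tracker) triple instead of mutating tensors in place.
FP32_MAX = 3.4028234663852886e+38  # torch.finfo(torch.float32).max

def update_scale_hysteresis_ref(scale, growth_tracker, hysteresis_tracker,
                                found_inf, growth_factor, backoff_factor,
                                growth_interval, hysteresis):
    if found_inf:
        hysteresis_tracker -= 1
        if hysteresis_tracker > 0:
            # Budget remains: keep the scale, reset growth progress.
            return scale, 0, hysteresis_tracker
        # Budget exhausted: back off the scale.
        return scale * backoff_factor, 0, hysteresis_tracker
    # Clean step: count it, grow every growth_interval steps, refill budget.
    growth_tracker += 1
    if growth_tracker == growth_interval:
        grown = scale * growth_factor
        if grown <= FP32_MAX:  # mirrors the kernel's isfinite() guard
            scale = grown
        growth_tracker = 0
    return scale, growth_tracker, hysteresis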
1 change: 1 addition & 0 deletions setup.py
@@ -195,6 +195,7 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int
"csrc/multi_tensor_novograd.cu",
"csrc/multi_tensor_lamb.cu",
"csrc/multi_tensor_lamb_mp.cu",
"csrc/update_scale_hysteresis.cu",
],
extra_compile_args={
"cxx": ["-O3"] + version_dependent_macros,
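Since the kernel is compiled into the amp_C extension, it only becomes available after apex is rebuilt with its CUDA extensions enabled. A quick smoke check for the new binding (hypothetical, not part of the PR):

# Verifies the rebuilt amp_C extension exposes the new kernel (not part of the PR).
import amp_C
print(hasattr(amp_C, "update_scale_hysteresis"))  # True once the kernel is compiled in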
102 changes: 102 additions & 0 deletions tests/L0/run_amp/test_update_scale_hysteresis.py
@@ -0,0 +1,102 @@
import unittest
import random
import math

import torch

try:
    import amp_C
    from amp_C import update_scale_hysteresis
    disabled = False
except ImportError as err:
    print("amp_C fused kernels unavailable, disabling TestUpdateScaleHysteresis. ImportError was ", err)
    disabled = True

def isfinite(val):
    return ((val >= torch.finfo(torch.float32).smallest_normal) and (val <= torch.finfo(torch.float32).max))

class TestUpdateScaleHysteresis(unittest.TestCase):

    def setUp(self):
        pass

    def tearDown(self):
        pass

    def update_scale_hysteresis_body(self, init_scale, growth_factor, backoff_factor,
                                     growth_interval, hysteresis):
        scale_ref = float(init_scale)
        grow_tracker_ref = 0
        hysteresis_tracker_ref = 0

        scale = torch.tensor([init_scale], dtype=torch.float32, device='cuda')
        growth_tracker = torch.tensor([0], dtype=torch.int32, device='cuda')
        hysteresis_tracker = torch.tensor([hysteresis], dtype=torch.int32, device='cuda')

        # Infs appear for hysteresis-1 iterations, scale shouldn't change
        found_inf = torch.tensor([1], dtype=torch.float32, device='cuda')
        for i in range(hysteresis-1):
            update_scale_hysteresis(scale, growth_tracker, hysteresis_tracker,
                                    found_inf, growth_factor, backoff_factor, growth_interval, hysteresis)
        self.assertTrue(scale.item() == init_scale)

        # No infs for growth_interval-1 iterations, scale shouldn't change
        found_inf.zero_()
        for i in range(growth_interval-1):
            update_scale_hysteresis(scale, growth_tracker, hysteresis_tracker,
                                    found_inf, growth_factor, backoff_factor, growth_interval, hysteresis)
        self.assertTrue(scale.item() == init_scale)

        # Infs appear for more than hysteresis iterations, scale should be backed off
        found_inf.fill_(1)
        extra_iters = random.randint(0, 1000)
        scale_before = scale.detach().item()
        scale_ref = scale_before
        for i in range(hysteresis + extra_iters):
            update_scale_hysteresis(scale, growth_tracker, hysteresis_tracker,
                                    found_inf, growth_factor, backoff_factor, growth_interval, hysteresis)
        for i in range(1 + extra_iters):
            # Scale is continuously backed off for each iteration with an inf
            scale_new = scale_ref * backoff_factor
            if isfinite(scale_new):
                scale_ref = scale_new
            else:
                scale_ref = 0  # Scale update kernel does not check for underflow when backing off, which results in zero
        self.assertTrue(scale.item() == scale_ref)

        # No infs for more than growth_interval iterations, scale should be increased
        found_inf.fill_(0)
        extra_iters = random.randint(0, 1000)
        scale_before = scale.detach().item()
        scale_ref = scale_before
        for i in range(growth_interval + extra_iters):
            update_scale_hysteresis(scale, growth_tracker, hysteresis_tracker,
                                    found_inf, growth_factor, backoff_factor, growth_interval, hysteresis)
        for i in range(1 + int(math.floor(extra_iters / growth_interval))):
            # Scale is grown every growth_interval iterations
            scale_new = scale_ref * growth_factor
            if isfinite(scale_new):
                scale_ref = scale_new
        self.assertTrue(scale.item() == scale_ref)


    @unittest.skipIf(disabled, "amp_C is unavailable")
    def test_fuzz(self):
        init_scale_list = [1, 1024, 65536]
        growth_factor_list = [1.0, 2.0, 4.0]
        backoff_factor_list = [0.5, 0.25]
        growth_interval_list = [10, 100]
        hysteresis_list = [10, 100]

        for init_scale in init_scale_list:
            for growth_factor in growth_factor_list:
                for backoff_factor in backoff_factor_list:
                    for growth_interval in growth_interval_list:
                        for hysteresis in hysteresis_list:
                            self.update_scale_hysteresis_body(init_scale, growth_factor,
                                                              backoff_factor, growth_interval, hysteresis)


if __name__ == '__main__':
    unittest.main()