This is a PyTorch implementation of FlashSR.
If you find this repository helpful, please consider citing the paper:

```bibtex
@article{im2025flashsr,
  title={FlashSR: One-step Versatile Audio Super-resolution via Diffusion Distillation},
  author={Im, Jaekwon and Nam, Juhan},
  journal={arXiv preprint arXiv:2501.10807},
  year={2025}
}
```

Clone the repository and run the setup scripts:

```shell
git clone git@github.com:jakeoneijk/FlashSR_Inference.git
cd FlashSR_Inference
source ./Script/0_conda_env_setup.sh
source ./Script/1_pytorch_install.sh
source ./Script/2_setup.sh
```
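Before moving on, it can be worth sanity-checking the environment. This is a minimal check of my own (not part of the repository's scripts), assuming the setup scripts above installed PyTorch into the active environment:

```python
# quick environment sanity check: PyTorch must be importable for inference
import torch

print('PyTorch version:', torch.__version__)
print('CUDA available:', torch.cuda.is_available())
```

If CUDA is unavailable, inference will still run on CPU, just more slowly.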
After installation, you can import the module from anywhere:

```python
from FlashSR.FlashSR import FlashSR

student_ldm_ckpt_path: str = './ModelWeights/student_ldm.pth'
sr_vocoder_ckpt_path: str = './ModelWeights/sr_vocoder.pth'
vae_ckpt_path: str = './ModelWeights/vae.pth'
flashsr = FlashSR(student_ldm_ckpt_path, sr_vocoder_ckpt_path, vae_ckpt_path)
```

Read the low-resolution audio:
```python
import torch

audio_path: str = './Assets/ExampleInput/music.wav'
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# read the audio and resample it to 48kHz
# audio.shape = [channel_size, time], e.g. [2, 245760]
# UtilAudio and UtilData are utilities installed during setup
audio, sr = UtilAudio.read(audio_path, sample_rate=48000)
# currently, the model only supports 245760 samples (5.12 seconds of audio)
audio = UtilData.fix_length(audio, 245760)
audio = audio.to(device)
```

Restore the high-frequency components with FlashSR:
```python
# lowpass_input: if True, apply a lowpass filter to the input audio before
# super-resolution. This can help reduce the discrepancy between the training
# data and the inference data.
pred_hr_audio = flashsr(audio, lowpass_input=False)
UtilAudio.write('./output.wav', pred_hr_audio, 48000)
```
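Because the model only accepts 245760-sample windows, longer clips have to be processed in chunks. Below is a minimal sketch of one way to do this; the `super_resolve_long` helper, its zero-padding strategy, and the assumption that the model returns output of the same length as its input are all my own and not part of the repository:

```python
import torch

CHUNK_SIZE = 245760  # 5.12 s at 48 kHz, the model's fixed input length

def super_resolve_long(audio: torch.Tensor, model, chunk_size: int = CHUNK_SIZE) -> torch.Tensor:
    """Run a fixed-input-length SR model over arbitrarily long audio.

    audio: [channel_size, time]; model: a callable mapping a
    [channel_size, chunk_size] tensor to a tensor of the same shape.
    """
    channels, length = audio.shape
    outputs = []
    for start in range(0, length, chunk_size):
        chunk = audio[:, start:start + chunk_size]
        pad = chunk_size - chunk.shape[-1]
        if pad > 0:  # zero-pad the final chunk up to the model's input length
            chunk = torch.nn.functional.pad(chunk, (0, pad))
        outputs.append(model(chunk))
    # stitch the chunks back together and trim the padding off the end
    return torch.cat(outputs, dim=-1)[:, :length]
```

With the model above this could be invoked as `super_resolve_long(audio, lambda x: flashsr(x, lowpass_input=False))`. Note that naive chunking can leave audible seams at chunk boundaries; an overlap-add scheme is worth considering for production use.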