@@ -608,7 +608,7 @@ def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeli
608608 else :
609609 pipelines .append (FmhaFwdPipeline ('qr_async' , 'row' , 't' , 'f' , 't' , 't' , logits , bias , lse , dropout , squant , mask , skip , 'f' ))
610610 pipelines .append (FmhaFwdPipeline ('qr_async' , 'row' , 't' , 't' , 't' , 't' , logits , bias , lse , dropout , squant , mask , skip , 'f' ))
611- if (hdim , hdim_v ) in [(64 , 64 ), (128 , 128 )] and logits == "f" and bias == "no" and dropout == "f" and skip == "f" :
611+ if (hdim , hdim_v ) in [(64 , 64 ), (128 , 128 )] and logits == "f" and bias == "no" and dropout == "f" and lse == "f" and skip == "f" :
612612 pipelines .append (FmhaFwdPipeline ('qr_async_trload' , 'row' , 'f' , 'f' , 'f' , 'f' , logits , bias , lse , dropout , squant , mask , skip , 't' ))
613613 pipelines .append (FmhaFwdPipeline ('qr_async_trload' , 'row' , 'f' , 'f' , 't' , 't' , logits , bias , lse , dropout , squant , mask , skip , 't' ))
614614 if receipt == 1 and bias != "bias" :
0 commit comments