From 328871f240276fdd6a86cd29e51b9a6dbde1cb1c Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Tue, 20 Aug 2024 11:46:56 +0200 Subject: [PATCH] add nx --- src/transformers/pytorch_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/pytorch_utils.py b/src/transformers/pytorch_utils.py index 50f7d6d862f7..4c74a04d4f34 100644 --- a/src/transformers/pytorch_utils.py +++ b/src/transformers/pytorch_utils.py @@ -96,6 +96,7 @@ class Conv1D(nn.Module): def __init__(self, nf, nx): super().__init__() self.nf = nf + self.nx = nx self.weight = nn.Parameter(torch.empty(nx, nf)) self.bias = nn.Parameter(torch.zeros(nf)) nn.init.normal_(self.weight, std=0.02)