diff --git a/n3fit/src/n3fit/layers/rotations.py b/n3fit/src/n3fit/layers/rotations.py index 686d3c5bbb..ed58cdbe15 100644 --- a/n3fit/src/n3fit/layers/rotations.py +++ b/n3fit/src/n3fit/layers/rotations.py @@ -56,44 +56,39 @@ def __init__( super().__init__(rotation_matrix, axes=1, **kwargs) -class FkRotation(MetaLayer): +class FkRotation(Rotation): """ - Applies a transformation from the dimension-8 evolution basis + Applies a transformation from the dimension-9 evolution basis to the dimension-14 evolution basis used by the fktables. The input to this layer is a `pdf_raw` variable which is expected to have - a shape (1, None, 8), and it is then rotated to an output (1, None, 14) + a shape (1, None, 9), and it is then rotated to an output (1, None, 14) """ - - # TODO: Generate a rotation matrix in the input and just do tf.tensordot in call - # the matrix should be: (8, 14) so that we can just do tf.tensordot(pdf, rotmat, axes=1) - # i.e., create the matrix and inherit from the Rotation layer above def __init__(self, output_dim=14, name="evolution", **kwargs): self.output_dim = output_dim - super().__init__(name=name, **kwargs) - - def call(self, pdf_raw): - # Transpose the PDF so that the flavour index is the first one - x = op.transpose(pdf_raw) - pdf_raw_list = [ - 0 * x[0], # photon - x[0], # sigma - x[1], # g - x[2], # v - x[3], # v3 - x[4], # v8 - x[8], # v15 - x[2], # v24 - x[2], # v35 - x[5], # t3 - x[6], # t8 - x[0] - 4 * x[7], # t15 (c-) - x[0], # t24 - x[0], # t35 - ] - ret = op.concatenate(pdf_raw_list) - # Concatenating destroys the batch index so we have to regenerate it - return op.batchit(ret) + rotation_matrix = self._create_rotation_matrix() + super().__init__(rotation_matrix, axes=1, name=name, **kwargs) + + def _create_rotation_matrix(self): + """Create the rotation matrix""" + array = np.array([ + [0, 0, 0, 0, 0, 0, 0, 0, 0], # photon + [1, 0, 0, 0, 0, 0, 0, 0, 0], # sigma + [0, 1, 0, 0, 0, 0, 0, 0, 0], # g + [0, 0, 1, 0, 0, 0, 0, 0, 0], # v + [0, 0, 0, 1, 0, 0, 0, 0, 0], # v3 + [0, 0, 0, 0, 1, 0, 0, 0, 0], # v8 + [0, 0, 0, 0, 0, 0, 0, 0, 1], # v15 + [0, 0, 1, 0, 0, 0, 0, 0, 0], # v24 + [0, 0, 1, 0, 0, 0, 0, 0, 0], # v35 + [0, 0, 0, 0, 0, 1, 0, 0, 0], # t3 + [0, 0, 0, 0, 0, 0, 1, 0, 0], # t8 + [1, 0, 0, 0, 0, 0, 0,-4, 0], # t15 (c-) + [1, 0, 0, 0, 0, 0, 0, 0, 0], # t24 + [1, 0, 0, 0, 0, 0, 0, 0, 0], # t35 + ]) + tensor = op.numpy_to_tensor(array.T) + return tensor class AddPhoton(MetaLayer):