From 883cb192f741c6f81134ee676ab0e941d2d5fe9a Mon Sep 17 00:00:00 2001 From: Aron Jansen Date: Mon, 10 Jul 2023 13:05:09 +0200 Subject: [PATCH 1/2] Implement FkRotation as subclass of Rotation by rewriting using a rotation tensor --- n3fit/src/n3fit/layers/rotations.py | 53 +++++++++++++---------------- 1 file changed, 24 insertions(+), 29 deletions(-) diff --git a/n3fit/src/n3fit/layers/rotations.py b/n3fit/src/n3fit/layers/rotations.py index 686d3c5bbb..6001a312bc 100644 --- a/n3fit/src/n3fit/layers/rotations.py +++ b/n3fit/src/n3fit/layers/rotations.py @@ -56,7 +56,7 @@ def __init__( super().__init__(rotation_matrix, axes=1, **kwargs) -class FkRotation(MetaLayer): +class FkRotation(Rotation): """ Applies a transformation from the dimension-8 evolution basis to the dimension-14 evolution basis used by the fktables. @@ -64,36 +64,31 @@ class FkRotation(MetaLayer): The input to this layer is a `pdf_raw` variable which is expected to have a shape (1, None, 8), and it is then rotated to an output (1, None, 14) """ - - # TODO: Generate a rotation matrix in the input and just do tf.tensordot in call - # the matrix should be: (8, 14) so that we can just do tf.tensordot(pdf, rotmat, axes=1) - # i.e., create the matrix and inherit from the Rotation layer above def __init__(self, output_dim=14, name="evolution", **kwargs): self.output_dim = output_dim - super().__init__(name=name, **kwargs) - - def call(self, pdf_raw): - # Transpose the PDF so that the flavour index is the first one - x = op.transpose(pdf_raw) - pdf_raw_list = [ - 0 * x[0], # photon - x[0], # sigma - x[1], # g - x[2], # v - x[3], # v3 - x[4], # v8 - x[8], # v15 - x[2], # v24 - x[2], # v35 - x[5], # t3 - x[6], # t8 - x[0] - 4 * x[7], # t15 (c-) - x[0], # t24 - x[0], # t35 - ] - ret = op.concatenate(pdf_raw_list) - # Concatenating destroys the batch index so we have to regenerate it - return op.batchit(ret) + rotation_matrix = self._create_rotation_matrix() + super().__init__(rotation_matrix, axes=1, name=name, **kwargs) + + def _create_rotation_matrix(self): + """Create the rotation matrix""" + array = np.array([ + [0, 0, 0, 0, 0, 0, 0, 0, 0], # photon + [1, 0, 0, 0, 0, 0, 0, 0, 0], # sigma + [0, 1, 0, 0, 0, 0, 0, 0, 0], # g + [0, 0, 1, 0, 0, 0, 0, 0, 0], # v + [0, 0, 0, 1, 0, 0, 0, 0, 0], # v3 + [0, 0, 0, 0, 1, 0, 0, 0, 0], # v8 + [0, 0, 0, 0, 0, 0, 0, 0, 1], # v15 + [0, 0, 1, 0, 0, 0, 0, 0, 0], # v24 + [0, 0, 1, 0, 0, 0, 0, 0, 0], # v35 + [0, 0, 0, 0, 0, 1, 0, 0, 0], # t3 + [0, 0, 0, 0, 0, 0, 1, 0, 0], # t8 + [1, 0, 0, 0, 0, 0, 0,-4, 0], # t15 (c-) + [1, 0, 0, 0, 0, 0, 0, 0, 0], # t24 + [1, 0, 0, 0, 0, 0, 0, 0, 0], # t35 + ]) + tensor = op.numpy_to_tensor(array.T) + return tensor class AddPhoton(MetaLayer): From e7a91405a59b13b5662541d38dc48e887e99c5e4 Mon Sep 17 00:00:00 2001 From: Aron Jansen Date: Mon, 10 Jul 2023 13:56:52 +0200 Subject: [PATCH 2/2] Change 8 to 9 in docstring --- n3fit/src/n3fit/layers/rotations.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/n3fit/src/n3fit/layers/rotations.py b/n3fit/src/n3fit/layers/rotations.py index 6001a312bc..ed58cdbe15 100644 --- a/n3fit/src/n3fit/layers/rotations.py +++ b/n3fit/src/n3fit/layers/rotations.py @@ -58,11 +58,11 @@ def __init__( class FkRotation(Rotation): """ - Applies a transformation from the dimension-8 evolution basis + Applies a transformation from the dimension-9 evolution basis to the dimension-14 evolution basis used by the fktables. The input to this layer is a `pdf_raw` variable which is expected to have - a shape (1, None, 8), and it is then rotated to an output (1, None, 14) + a shape (1, None, 9), and it is then rotated to an output (1, None, 14) """ def __init__(self, output_dim=14, name="evolution", **kwargs): self.output_dim = output_dim