Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 26 additions & 31 deletions n3fit/src/n3fit/layers/rotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,44 +56,39 @@ def __init__(
super().__init__(rotation_matrix, axes=1, **kwargs)


class FkRotation(MetaLayer):
class FkRotation(Rotation):
"""
Applies a transformation from the dimension-8 evolution basis
Applies a transformation from the dimension-9 evolution basis
to the dimension-14 evolution basis used by the fktables.

The input to this layer is a `pdf_raw` variable which is expected to have
a shape (1, None, 8), and it is then rotated to an output (1, None, 14)
a shape (1, None, 9), and it is then rotated to an output (1, None, 14)
"""

# TODO: Generate a rotation matrix in the input and just do tf.tensordot in call
# the matrix should be: (8, 14) so that we can just do tf.tensordot(pdf, rotmat, axes=1)
# i.e., create the matrix and inherit from the Rotation layer above
def __init__(self, output_dim=14, name="evolution", **kwargs):
self.output_dim = output_dim
super().__init__(name=name, **kwargs)

def call(self, pdf_raw):
# Transpose the PDF so that the flavour index is the first one
x = op.transpose(pdf_raw)
pdf_raw_list = [
0 * x[0], # photon
x[0], # sigma
x[1], # g
x[2], # v
x[3], # v3
x[4], # v8
x[8], # v15
x[2], # v24
x[2], # v35
x[5], # t3
x[6], # t8
x[0] - 4 * x[7], # t15 (c-)
x[0], # t24
x[0], # t35
]
ret = op.concatenate(pdf_raw_list)
# Concatenating destroys the batch index so we have to regenerate it
return op.batchit(ret)
rotation_matrix = self._create_rotation_matrix()
super().__init__(rotation_matrix, axes=1, name=name, **kwargs)

def _create_rotation_matrix(self):
"""Create the rotation matrix"""
array = np.array([
[0, 0, 0, 0, 0, 0, 0, 0, 0], # photon
[1, 0, 0, 0, 0, 0, 0, 0, 0], # sigma
[0, 1, 0, 0, 0, 0, 0, 0, 0], # g
[0, 0, 1, 0, 0, 0, 0, 0, 0], # v
[0, 0, 0, 1, 0, 0, 0, 0, 0], # v3
[0, 0, 0, 0, 1, 0, 0, 0, 0], # v8
[0, 0, 0, 0, 0, 0, 0, 0, 1], # v15
[0, 0, 1, 0, 0, 0, 0, 0, 0], # v24
[0, 0, 1, 0, 0, 0, 0, 0, 0], # v35
[0, 0, 0, 0, 0, 1, 0, 0, 0], # t3
[0, 0, 0, 0, 0, 0, 1, 0, 0], # t8
[1, 0, 0, 0, 0, 0, 0,-4, 0], # t15 (c-)
[1, 0, 0, 0, 0, 0, 0, 0, 0], # t24
[1, 0, 0, 0, 0, 0, 0, 0, 0], # t35
])
tensor = op.numpy_to_tensor(array.T)
return tensor


class AddPhoton(MetaLayer):
Expand Down