Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 46 additions & 36 deletions config_adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
"""

from __future__ import annotations
from typing import (Dict, List, ClassVar, Union, Any, Optional, TYPE_CHECKING)
from typing import (List, ClassVar, Union, Optional, TYPE_CHECKING)

from dataclasses import dataclass, asdict, fields, InitVar
from abc import ABC, abstractmethod
import re

if TYPE_CHECKING:
from drmaa import JobTemplate
Expand All @@ -19,6 +20,8 @@
"remoteCommand", "args", "jobName", "jobCategory", "blockEmail"
]

TIMESTR_VALIDATE = re.compile("^(\\d+:)?[0-9][0-9]:[0-9][0-9]$")


@dataclass
class DRMAACompatible(ABC):
Expand All @@ -33,7 +36,7 @@ class DRMAACompatible(ABC):
the corresponding native specification
'''

_mapped_fields: ClassVar[Dict[str, Any]]
_mapped_fields: ClassVar[List[str]]

def __str__(self):
'''
Expand Down Expand Up @@ -72,35 +75,21 @@ def drm2drmaa(self) -> str:
Build native specification from DRM-specific fields
'''

def _map_fields(self, **drm_kwargs: Dict[str, Any]):
'''
Transform fields in `_mapped_fields` to
DRMAA-compliant specification. Adds
DRM-specific attributes to `self`

Arguments:
drm_kwargs: DRM-specific key-value pairs
'''
for drm_name, value in drm_kwargs.items():
try:
drmaa_name = self._mapped_fields[drm_name]
except KeyError:
raise AttributeError(
"Malformed adapter class! Cannot map field"
f"{drm_name} to a DRMAA-compliant field")

setattr(self, drmaa_name, value)

def __post_init__(self, **kwargs):
self._map_fields(**kwargs)

def _native_fields(self):
return [
f for f in asdict(self).keys()
if (f not in self._mapped_fields.keys()) and (
f not in DRMAA_FIELDS)
if (f not in self._mapped_fields) and (f not in DRMAA_FIELDS)
]

def set_fields(self, **drmaa_kwargs):
for field, value in drmaa_kwargs.items():
if field not in DRMAA_FIELDS:
raise AttributeError(
"Malformed adapter class! Cannot map field"
f" {field} to a DRMAA-compliant field")

setattr(self, field, value)


@dataclass
class DRMAAConfig(DRMAACompatible):
Expand All @@ -118,11 +107,8 @@ class SlurmConfig(DRMAACompatible):
details
'''

_mapped_fields: ClassVar[Dict[str, Any]] = {
"error": "errorPath",
"output": "outputPath",
"job_name": "jobName",
"time": "hardWallclockTimeLimit"
_mapped_fields: ClassVar[List[str]] = {
"error", "output", "job_name", "time"
}

job_name: InitVar[str]
Expand Down Expand Up @@ -161,13 +147,18 @@ class SlurmConfig(DRMAACompatible):
def __post_init__(self, job_name, time, error, output):
'''
Transform Union[List[str]] --> comma-delimited str
In addition map time to seconds
'''

super().__post_init__(job_name=job_name,
time=_timestr_to_sec(time),
error=error,
output=output)
_validate_timestr(time, "time")
super().set_fields(jobName=job_name,
hardWallclockTimeLimit=time,
errorPath=error,
outputPath=output)

self.job_name = job_name
self.time = time
self.error = error
self.output = output

for field in fields(self):
value = getattr(self, field.name)
Expand Down Expand Up @@ -214,3 +205,22 @@ def _timestr_to_sec(timestr: str) -> int:
seconds += int(unit) * (60**exp)

return seconds


def _validate_timestr(timestr: str, field_name: str) -> str:
'''
Validate timestring to make sure it meets
expected format.
'''

if not isinstance(timestr, str):
raise TypeError(f"Expected {field_name} to be of type string "
f"but received {type(timestr)}!")

result = TIMESTR_VALIDATE.match(timestr)
if not result:
raise ValueError(f"Expected {field_name} to be of format "
"X...XX:XX:XX or XX:XX! "
f"but received {timestr}")

return timestr
11 changes: 10 additions & 1 deletion drmaa_patches.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,13 @@
from drmaa.helpers import Attribute, IntConverter


#TODO: Make sure this is actually correct?
# Works for SLURM
CORRECT_TO_STRING = [
"hardWallclockTimeLimit"
]


class PatchedIntConverter():
'''
Helper class to correctly encode Integer values
Expand Down Expand Up @@ -33,7 +40,9 @@ def __init__(self):
super(PatchedJobTemplate, self).__init__()
for attr, value in vars(JobTemplate).items():
if isinstance(value, Attribute):
if value.converter is IntConverter:
if attr in CORRECT_TO_STRING:
setattr(value, "converter", None)
elif value.converter is IntConverter:
setattr(value, "converter", PatchedIntConverter)


Expand Down
32 changes: 29 additions & 3 deletions tests/test_config_adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,13 @@ def test_slurm_config_transforms_to_drmaa(job_template):

error = "TEST_VALUE"
output = "TEST_VALUE"
time = "10:00:00" # must test as seconds
time = "10:00:00"
job_name = "FAKE_JOB"

expected_drmaa_attrs = {
"errorPath": error,
"outputPath": output,
"hardWallclockTimeLimit": 36000,
"hardWallclockTimeLimit": "10:00:00",
"jobName": job_name
}

Expand All @@ -54,7 +54,7 @@ def test_slurm_config_native_spec_transforms_correctly(job_template):
'''

job_name = "TEST"
time = "1:00"
time = "01:00"
account = "TEST"
cpus_per_task = 5
slurm_config = SlurmConfig(job_name=job_name,
Expand All @@ -65,3 +65,29 @@ def test_slurm_config_native_spec_transforms_correctly(job_template):
jt = slurm_config.get_drmaa_config(job_template)
for spec in ['account=TEST', 'cpus-per-task=5']:
assert spec in jt.nativeSpecification


def test_invalid_timestr_fails():
job_name = "TEST"
time = "FAILURE"
account = "TEST"
cpus_per_task = 10

with pytest.raises(ValueError):
SlurmConfig(job_name=job_name,
time=time,
account=account,
cpus_per_task=cpus_per_task)


def test_timestr_not_string_fails():
job_name = "TEST"
time = 10
account = "TEST"
cpus_per_task = 10

with pytest.raises(TypeError):
SlurmConfig(job_name=job_name,
time=time,
account=account,
cpus_per_task=cpus_per_task)