-
Notifications
You must be signed in to change notification settings - Fork 7
Improving Code Quality of the Project #20
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Conversation
|
|
||
| result = self.interpolator(points[not a]) | ||
|
|
||
| condition = ((points[:, 0] < 0) + (points[:, 0] >= self.data.dwi.shape[0]) + # OR |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add a comment explaining the condition, i.e. say that it is checking whether we are within the domain of our dwi data
| out_memmap[idx:(idx + lengths[i])] = out.numpy() | ||
| idx = idx + lengths[i] | ||
| assert len(self) == len(lengths) | ||
| for i, ((inp, out), length) in enumerate(zip(self, lengths)): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
please add a comment that explains what you are doing here
| def __len__(self): | ||
| return len(self.sl_lengths) | ||
|
|
||
| def __getitem__(self, index): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
add documentation to this function, i.e. what are you returning, dimensions, etc.
| self.feature_shapes = (torch.prod(input_shape).item(), torch.prod(output_shape).item()) | ||
| return self.feature_shapes | ||
|
|
||
| def cuda(self, device=None, non_blocking=False, memory_format=torch.preserve_format): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
explain why you are implementing the cuda function here + when to use it :)
| self.device = dwi.device | ||
| return self | ||
|
|
||
| def cpu(self, memory_format=torch.preserve_format): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
explain why you are using a cpu function + when to use it
| import dfibert.envs.tractography as RLTe | ||
|
|
||
|
|
||
| def train(path, max_steps=3000000, replay_memory_size=20000, eps_annealing_steps=100000, agent_history_length=1, evaluate_every=20000, eval_runs=5, network_update_every=10000, max_episode_length=200, learning_rate=0.0000625): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
documentation missing but would be really helpful. @phanfeld might probably provide more insights in the trainign algorithm etc
| device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | ||
| print("Init environment..") | ||
| env = RLTe.RLtractEnvironment(device = 'cpu') | ||
| env = RLTe.EnvTractography(device = 'cpu') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why do you actually set the device to cpu?
This reverts commit 7c097f8.
Went through the whole codebase and fixed sensible code quality issues highlighted by pylint.
Closes #15.