Allow flexible returns from a forward function.#940
Conversation
|
Check out this pull request on See visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB |
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #940 +/- ##
==========================================
+ Coverage 62.94% 63.06% +0.11%
==========================================
Files 71 71
Lines 7357 7399 +42
==========================================
+ Hits 4631 4666 +35
- Misses 2726 2733 +7 ☔ View full report in Codecov by Harness. 🚀 New features to boost your workflow:
|
drewoldag
left a comment
There was a problem hiding this comment.
This is looking like it's going in the right direction. I've got some comments. Might be easier to discuss these in person tomorrow if you've got the time.
| # Return single tensor or array of tensors | ||
| return tensors[0] if is_single else tensors | ||
|
|
||
| def get_combined_tensor(self, idx: int): |
There was a problem hiding this comment.
I'm unsure about this approach. What I had in mind was something more dynamic. Similar to what we do in teh CSVDataset.
https://github.com/lincc-frameworks/hyrax/blob/main/src/hyrax/datasets/hyrax_csv_dataset.py#L51-L63
With a little bit of introspection at runtime, we can check the schema, and then dynamically build getter methods for each of the different columns that have been saved.
And at that point, we would really have to put the responsibility on the user to do things like select the result dataset field that is going to be operated on. e.g. UMAP reduced, etc...
I think that this might have follow on implications for many other parts of this PR, so perhaps we can spend some time chatting about it tomorrow?
There was a problem hiding this comment.
I'm going to need to work my way though this a bit more carefully. Ideally we would be a in a situation where backward compatibility would be maintained. I'm a little concerned about that here.
Did you by any chance have an opportunity to do some backward compatability testing?
| import lancedb | ||
| import numpy as np | ||
| import pyarrow as pa | ||
| import torch |
There was a problem hiding this comment.
I'm a little nervous about this import. It's been nice being able to keep pytorch out of our datasets. Is this just here for type checking?
Co-authored-by: Drew Oldag <47493171+drewoldag@users.noreply.github.com>
Change Description
resolves #858
hyraxusers can now return a dict oftorch.Tensors from a model'sforwardfunction.Solution Description
dictwith the format{ "data": <tensor> }, which is how the data will eventually be saved to lance db anyway.ResultDatasetshould now always return a dict with the data for a give index or set of indicesResultDataset.get_combined_tensor, which for a given index or list of indices will return a concatenated and flattened tensor of all the available data keys.Code Quality