Skip to content

Allow flexible returns from a forward function.#940

Open
maxwest-uw wants to merge 22 commits into
mainfrom
issue/858
Open

Allow flexible returns from a forward function.#940
maxwest-uw wants to merge 22 commits into
mainfrom
issue/858

Conversation

@maxwest-uw

Copy link
Copy Markdown
Collaborator

Change Description

resolves #858

hyrax users can now return a dict of torch.Tensors from a model's forward function.

Solution Description

  • change is backwards compatible (i.e. user can return either a dict or a tensor from the forward function and everything should still work)
  • in cases where the forward function just returns a tensor, the data will eventually be converted into a dict with the format { "data": <tensor> }, which is how the data will eventually be saved to lance db anyway.
  • ResultDataset should now always return a dict with the data for a give index or set of indices
  • in cases throughout the data processing pipeline where we assume that a result is going to be a tensor, we have added a new function ResultDataset.get_combined_tensor, which for a given index or list of indices will return a concatenated and flattened tensor of all the available data keys.

Code Quality

  • I have read the Contribution Guide and agree to the Code of Conduct
  • My code follows the code style of this project
  • My code builds (or compiles) cleanly without any errors or warnings
  • My code contains relevant comments and necessary documentation

@review-notebook-app

Copy link
Copy Markdown

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

@codecov

codecov Bot commented Jun 8, 2026

Copy link
Copy Markdown

Codecov Report

❌ Patch coverage is 87.67123% with 9 lines in your changes missing coverage. Please review.
✅ Project coverage is 63.06%. Comparing base (3c2fdd5) to head (6f55807).

Files with missing lines Patch % Lines
src/hyrax/datasets/result_dataset.py 93.33% 4 Missing ⚠️
src/hyrax/pytorch_ignite.py 60.00% 4 Missing ⚠️
src/hyrax/verbs/visualize.py 0.00% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #940      +/-   ##
==========================================
+ Coverage   62.94%   63.06%   +0.11%     
==========================================
  Files          71       71              
  Lines        7357     7399      +42     
==========================================
+ Hits         4631     4666      +35     
- Misses       2726     2733       +7     

☔ View full report in Codecov by Harness.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.
  • 📦 JS Bundle Analysis: Save yourself from yourself by tracking and limiting bundle sizes in JS merges.

@maxwest-uw maxwest-uw requested a review from drewoldag June 8, 2026 20:34

@drewoldag drewoldag left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is looking like it's going in the right direction. I've got some comments. Might be easier to discuss these in person tomorrow if you've got the time.

# Return single tensor or array of tensors
return tensors[0] if is_single else tensors

def get_combined_tensor(self, idx: int):

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm unsure about this approach. What I had in mind was something more dynamic. Similar to what we do in teh CSVDataset.

https://github.com/lincc-frameworks/hyrax/blob/main/src/hyrax/datasets/hyrax_csv_dataset.py#L51-L63

With a little bit of introspection at runtime, we can check the schema, and then dynamically build getter methods for each of the different columns that have been saved.

And at that point, we would really have to put the responsibility on the user to do things like select the result dataset field that is going to be operated on. e.g. UMAP reduced, etc...

I think that this might have follow on implications for many other parts of this PR, so perhaps we can spend some time chatting about it tomorrow?

Comment thread src/hyrax/verbs/umap.py

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm going to need to work my way though this a bit more carefully. Ideally we would be a in a situation where backward compatibility would be maintained. I'm a little concerned about that here.

Did you by any chance have an opportunity to do some backward compatability testing?

import lancedb
import numpy as np
import pyarrow as pa
import torch

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a little nervous about this import. It's been nice being able to keep pytorch out of our datasets. Is this just here for type checking?

Comment thread src/hyrax/pytorch_ignite.py Outdated
Comment thread src/hyrax/pytorch_ignite.py Outdated
Comment thread scripts/convert_results.py Outdated
Co-authored-by: Drew Oldag <47493171+drewoldag@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Allow more flexible returns from forward

2 participants