Make TFRecord work with dynamic mode #6151
Signed-off-by: Rostan Tabet <rtabet@nvidia.com>
Greptile Summary
This PR adds support for dictionary return values in DALI's dynamic mode, specifically fixing TFRecord reader compatibility. The changes propagate dictionary handling through the entire invocation pipeline.
The implementation is clean and follows the existing architecture patterns. The changes are well-contained and don't affect operators that return tuples.
Confidence Score: 5/5
Sequence Diagram
sequenceDiagram
participant User
participant TFRecord as TFRecord Reader
participant Operator as Operator._run()
participant Invocation as Invocation._run_impl()
participant OpBuilder as build_call_function()
participant Reader as Reader._samples()/_batches()
User->>TFRecord: next_epoch(batch_size)
TFRecord->>Reader: _samples() or _batches()
Reader->>Operator: _run(ctx, batch_size)
Operator->>Invocation: Create Invocation
Invocation->>Invocation: _run_impl(ctx)
Note over Invocation: Execute operator backend
Invocation->>Invocation: Check result type
alt Result is dict
Invocation->>Invocation: Convert to tuple(r.values())
else Result is tuple/list
Invocation->>Invocation: Keep as tuple
else Result is single value
Invocation->>Invocation: Wrap in tuple
end
Invocation-->>Operator: Return tuple results
Note over Operator: Check _output_names
alt _output_names is set
Operator->>Operator: zip(names, results) → dict
else _output_names is None
Operator->>Operator: Return tuple as-is
end
Operator-->>Reader: dict or tuple
alt Result is dict
Reader->>Reader: Iterate over dict.values()
Reader->>Reader: yield dict(zip(names, tensors))
else Result is tuple
Reader->>Reader: Iterate over tuple
Reader->>Reader: yield tuple(tensors)
end
Reader-->>User: dict[str, Tensor/Batch]
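The two branches in the diagram (normalizing the backend result to a tuple, then re-attaching output names) can be sketched in plain Python. This is a minimal illustration only: `normalize_results` and `attach_names` are hypothetical helper names, not functions from the DALI code base.

```python
from typing import Any, Optional, Sequence


def normalize_results(r: Any) -> tuple:
    """Mirror the "Check result type" branch: always hand back a tuple."""
    if isinstance(r, dict):
        # Dict results are flattened to their values, in insertion order.
        return tuple(r.values())
    if isinstance(r, (tuple, list)):
        return tuple(r)
    # A single value is wrapped so callers can always index the result.
    return (r,)


def attach_names(results: tuple, output_names: Optional[Sequence[str]]):
    """Mirror the "_output_names" branch: rebuild a dict when names are known."""
    if output_names is not None:
        return dict(zip(output_names, results))
    return results


raw = {"image": "encoded-bytes", "label": 3}  # stand-in for a TFRecord result
flat = normalize_results(raw)
named = attach_names(flat, list(raw))
```

Because Python dicts preserve insertion order, flattening with `values()` and zipping the names back on reproduces the original mapping, as long as the names are listed in the same order.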
!build
CI MESSAGE: [41073266]: BUILD STARTED
CI MESSAGE: [41073266]: BUILD PASSED
self.run(self._eval_context)
return self._results[result_index].layout()

def __iter__(self):
Out of curiosity: what is this needed for? You can iterate an object based on __len__/__getitem__.
It's true that __getitem__ is enough to make an object iterable, but that fallback adds extra overhead, and this can be called in a hot path.
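For illustration, here is a small sketch of the difference being discussed. The classes are hypothetical and unrelated to DALI; the counter only makes the fallback's per-element Python-level calls visible.

```python
class ViaGetItem:
    """Iterable only through the legacy sequence protocol."""

    def __init__(self, items):
        self._items = list(items)
        self.getitem_calls = 0

    def __len__(self):
        return len(self._items)

    def __getitem__(self, index):
        # iter() falls back to calling this with 0, 1, 2, ... until IndexError.
        self.getitem_calls += 1
        return self._items[index]


class ViaIter(ViaGetItem):
    """Same container, but with an explicit __iter__."""

    def __iter__(self):
        # Delegates to the list iterator, bypassing __getitem__ entirely.
        return iter(self._items)


slow = ViaGetItem([10, 20, 30])
fast = ViaIter([10, 20, 30])
slow_values = list(slow)
fast_values = list(fast)
```

Both produce the same values, but the fallback path goes through `__getitem__` once per element plus one final call that raises `IndexError` to stop the iteration.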
self._num_outputs = self._operator._infer_num_outputs(*self._inputs, **self._args)
assert self._num_outputs is not None
self._num_outputs = self._operator._infer_num_outputs(*self._inputs, **self._args)
assert self._num_outputs is not None
No need to run the assert all the time, especially right after we've checked that the requested condition is met.
True. I did it because type checkers infer the return type as Any | int | None, and they really hate that __len__ can return None, but there's actually a less expensive way to fix it.
Fixed in 155b391.
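One way to keep type checkers satisfied without re-checking on every call is to validate the value only when it is first inferred. The sketch below is an assumption about how that could look, not necessarily what the commit above does; `InvocationSketch` and its `infer` callback are hypothetical stand-ins for the real invocation class and `_infer_num_outputs`.

```python
from typing import Callable, Optional


class InvocationSketch:
    def __init__(self, infer: Callable[[], Optional[int]]):
        self._infer = infer  # hypothetical stand-in for _infer_num_outputs
        self._num_outputs: Optional[int] = None

    def __len__(self) -> int:
        n = self._num_outputs
        if n is None:
            # Slow path: runs once, so the None check is paid only here.
            n = self._infer()
            if n is None:
                raise RuntimeError("could not infer the number of outputs")
            self._num_outputs = n
        # On both paths a type checker can narrow n to int.
        return n


calls = []


def infer_three() -> Optional[int]:
    calls.append(1)
    return 3


inv = InvocationSketch(infer_three)
first = len(inv)
second = len(inv)
```

Reading the attribute into a local variable is what lets the checker narrow the type; the cached fast path then does a single `is None` comparison and nothing else.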
return tuple(
    Batch(invocation_result=invocation[i]) for i in range(len(invocation))
)
cls = Batch if is_batch else Tensor
I'd recommend something more precise than cls, which might be confused for the operator class.
cls = Batch if is_batch else Tensor
ResultType = Batch if is_batch else Tensor
Fixed in 64d0c4a
if is_batch:

if self._output_names is not None:
    return dict(zip(self._output_names, tuple(out)))
Does it work with non-batch outputs?
Shouldn't you rather convert out to tensors based on is_batch and make a dictionary afterwards?
Fixed in ce374de
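The ordering suggested in this thread (wrap each output according to is_batch first, build the dictionary afterwards) can be sketched as follows. This is an illustration under assumptions, not the code from the commit: `Tensor` and `Batch` here are placeholder dataclasses rather than the DALI types, and `wrap_outputs` is a hypothetical helper.

```python
from dataclasses import dataclass
from typing import Any, Optional, Sequence


@dataclass
class Tensor:  # placeholder for the real Tensor type
    value: Any


@dataclass
class Batch:  # placeholder for the real Batch type
    value: Any


def wrap_outputs(raw: Sequence[Any], is_batch: bool,
                 output_names: Optional[Sequence[str]] = None):
    # Pick the wrapper once, then apply it to every raw output.
    ResultType = Batch if is_batch else Tensor
    wrapped = tuple(ResultType(r) for r in raw)
    # Only after wrapping are the names attached, so both the batch and the
    # non-batch paths can return a dict.
    if output_names is not None:
        return dict(zip(output_names, wrapped))
    return wrapped


sample = wrap_outputs([1, 2], is_batch=False, output_names=["image", "label"])
batch = wrap_outputs([[1, 2], [3, 4]], is_batch=True)
```

Doing the wrapping before the naming step keeps a single code path for dict and tuple results, which is the concern raised in the question about non-batch outputs.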
Signed-off-by: Rostan Tabet <rtabet@nvidia.com>
!build
CI MESSAGE: [41161266]: BUILD STARTED
Signed-off-by: Rostan Tabet <rtabet@nvidia.com>
!build
CI MESSAGE: [41163669]: BUILD STARTED
CI MESSAGE: [41163669]: BUILD FAILED
CI MESSAGE: [41163669]: BUILD PASSED
Category: Bug fix (non-breaking change which fixes an issue)
Description:
TFRecord returns a dictionary, which dynamic mode doesn't handle properly. This PR fixes this issue.
Additional information:
Affected modules and functionalities:
Key points relevant for the review:
Tests:
Checklist
Documentation
DALI team only
Requirements
REQ IDs: N/A
JIRA TASK: N/A