Skip to content

Commit 805df30

Browse files
authored
Reduce the diff with the pytorch:main branch (#3)
1 parent d0d23bd commit 805df30

5 files changed

Lines changed: 22 additions & 33 deletions

File tree

src/torchcodec/_core/AVIOTensorContext.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -123,10 +123,8 @@ AVIOToTensorContext::AVIOToTensorContext()
123123
}
124124

125125
torch::Tensor AVIOToTensorContext::getOutputTensor() {
126-
throw std::runtime_error(
127-
"AVIOToTensorContext::getOutputTensor is not implemented yet.");
128-
// return tensorContext_.data.narrow(
129-
// /*dim=*/0, /*start=*/0, /*length=*/tensorContext_.max_pos);
126+
return tensorContext_.data.narrow(
127+
/*dim=*/0, /*start=*/0, /*length=*/tensorContext_.max_pos);
130128
}
131129

132130
} // namespace facebook::torchcodec

src/torchcodec/_core/SingleStreamDecoder.cpp

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1030,15 +1030,15 @@ AudioFramesOutput SingleStreamDecoder::getFramesPlayedInRangeAudio(
10301030
frames.push_back(*lastSamples);
10311031
}
10321032

1033-
// TORCH_CHECK(
1034-
// frames.size() > 0 && firstFramePtsSeconds.has_value(),
1035-
// "No audio frames were decoded. ",
1036-
// "This is probably because start_seconds is too high(",
1037-
// startSeconds,
1038-
// "),",
1039-
// "or because stop_seconds(",
1040-
// stopSecondsOptional,
1041-
// ") is too low.");
1033+
TORCH_CHECK(
1034+
frames.size() > 0 && firstFramePtsSeconds.has_value(),
1035+
"No audio frames were decoded. ",
1036+
"This is probably because start_seconds is too high(",
1037+
startSeconds,
1038+
"),",
1039+
"or because stop_seconds(",
1040+
stopSecondsOptional,
1041+
") is too low.");
10421042

10431043
return AudioFramesOutput{torch::cat(frames, 1), *firstFramePtsSeconds};
10441044
}
@@ -1419,11 +1419,8 @@ std::optional<torch::Tensor> SingleStreamDecoder::maybeFlushSwrBuffers() {
14191419
auto actualNumRemainingSamples = swr_convert(
14201420
swrContext_.get(), outputBuffers.data(), numRemainingSamples, nullptr, 0);
14211421

1422-
throw std::runtime_error(
1423-
"SingleStreamDecoder::maybeFlushSwrBuffers is not implemented yet.");
1424-
1425-
// return lastSamples.narrow(
1426-
// /*dim=*/1, /*start=*/0, /*length=*/actualNumRemainingSamples);
1422+
return lastSamples.narrow(
1423+
/*dim=*/1, /*start=*/0, /*length=*/actualNumRemainingSamples);
14271424
}
14281425

14291426
// --------------------------------------------------------------------------

src/torchcodec/_core/custom_ops.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,12 @@
1010
#include <string>
1111
#include "c10/core/SymIntArrayRef.h"
1212
#include "c10/util/Exception.h"
13-
#include "torch/library.h"
1413
#include "src/torchcodec/_core/AVIOFileLikeContext.h"
1514
#include "src/torchcodec/_core/AVIOTensorContext.h"
1615
#include "src/torchcodec/_core/Encoder.h"
1716
#include "src/torchcodec/_core/SingleStreamDecoder.h"
1817
#include "src/torchcodec/_core/ValidationUtils.h"
18+
#include "torch/library.h"
1919

2020
namespace facebook::torchcodec {
2121

@@ -118,7 +118,7 @@ OpsFrameOutput makeOpsFrameOutput(FrameOutput& frame) {
118118
// frame.data,
119119
// torch::tensor(frame.ptsSeconds, torch::dtype(torch::kFloat64)),
120120
// torch::tensor(frame.durationSeconds, torch::dtype(torch::kFloat64)));
121-
return std::make_tuple(
121+
return std::make_tuple(
122122
frame.data,
123123
torch::full({}, frame.ptsSeconds, torch::kFloat64),
124124
torch::full({}, frame.durationSeconds, torch::kFloat64));
@@ -920,15 +920,15 @@ void scan_all_streams_to_update_metadata(at::Tensor& decoder) {
920920
videoDecoder->scanFileAndUpdateMetadataAndIndex();
921921
}
922922

923-
TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) {
923+
TORCH_LIBRARY_IMPL(torchcodec_ns, BackendSelect, m) {
924924
m.impl("create_from_file", &create_from_file);
925925
m.impl("create_from_tensor", &create_from_tensor);
926926
m.impl("_create_from_file_like", &_create_from_file_like);
927927
m.impl(
928928
"_get_json_ffmpeg_library_versions", &_get_json_ffmpeg_library_versions);
929-
// }
929+
}
930930

931-
// TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) {
931+
TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) {
932932
m.impl("encode_audio_to_file", &encode_audio_to_file);
933933
m.impl("encode_audio_to_tensor", &encode_audio_to_tensor);
934934
m.impl("_encode_audio_to_file_like", &_encode_audio_to_file_like);

src/torchcodec/decoders/_video_decoder.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ def __init__(
146146
# if isinstance(device, torch_device):
147147
# device = str(device)
148148
import paddle
149+
149150
if isinstance(device, paddle.base.core.Place):
150151
if device.is_cpu_place():
151152
return "cpu"
@@ -158,12 +159,11 @@ def __init__(
158159

159160
core.add_video_stream(
160161
self._decoder,
161-
num_threads=num_ffmpeg_threads,
162-
dimension_order=dimension_order,
163162
stream_index=stream_index,
163+
dimension_order=dimension_order,
164+
num_threads=num_ffmpeg_threads,
164165
device=device,
165166
device_variant=device_variant,
166-
transform_specs="",
167167
custom_frame_mappings=custom_frame_mappings_data,
168168
)
169169

@@ -265,9 +265,6 @@ def get_frames_at(self, indices: Union[torch.Tensor, list[int]]) -> FrameBatch:
265265
FrameBatch: The frames at the given indices.
266266
"""
267267

268-
if isinstance(indices, list):
269-
indices = torch.tensor(indices, dtype=torch.int64).cpu()
270-
271268
data, pts_seconds, duration_seconds = core.get_frames_at_indices(
272269
self._decoder, frame_indices=indices
273270
)
@@ -347,9 +344,6 @@ def get_frames_played_at(
347344
FrameBatch: The frames that are played at ``seconds``.
348345
"""
349346

350-
if isinstance(seconds, list):
351-
seconds = torch.tensor(seconds, dtype=torch.float32).cpu()
352-
353347
data, pts_seconds, duration_seconds = core.get_frames_by_pts(
354348
self._decoder, timestamps=seconds
355349
)

src/torchcodec/samplers/_index_based.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ def _generic_index_based_sampler(
151151

152152
if kind == "random":
153153
clip_start_indices = torch.randint(
154-
sampling_range_start, sampling_range_end, (num_clips,)
154+
low=sampling_range_start, high=sampling_range_end, size=(num_clips,)
155155
)
156156
else:
157157
# Note [num clips larger than sampling range]

0 commit comments

Comments
 (0)