diff --git a/.github/workflows/build_ffmpeg.yaml b/.github/workflows/build_ffmpeg.yaml index 9b9317e3b..5494bf0b1 100644 --- a/.github/workflows/build_ffmpeg.yaml +++ b/.github/workflows/build_ffmpeg.yaml @@ -48,6 +48,33 @@ jobs: mkdir -p "${artifact_dir}" mv ffmpeg.tar.gz "${artifact_dir}/${FFMPEG_VERSION}.tar.gz" + LGPL-Linux-aarch64: + strategy: + fail-fast: false + matrix: + ffmpeg-version: ["4.4.4", "5.1.4", "6.1.1", "7.0.1", "8.0"] + uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main + permissions: + id-token: write + contents: read + with: + job-name: Build + upload-artifact: ffmpeg-lgpl-linux_aarch64-${{ matrix.ffmpeg-version }} + repository: meta-pytorch/torchcodec + runner: linux.arm64.2xlarge + docker-image: pytorch/manylinux2_28_aarch64-builder:cpu-aarch64 + script: | + export FFMPEG_VERSION="${{ matrix.ffmpeg-version }}" + export FFMPEG_ROOT="${PWD}/ffmpeg" + + packaging/build_ffmpeg.sh + + tar -cf ffmpeg.tar.gz ffmpeg/include ffmpeg/lib + + artifact_dir="${RUNNER_ARTIFACT_DIR}/$(date +%Y-%m-%d)/linux_aarch64" + mkdir -p "${artifact_dir}" + mv ffmpeg.tar.gz "${artifact_dir}/${FFMPEG_VERSION}.tar.gz" + LGPL-macOS: strategy: fail-fast: false diff --git a/.github/workflows/cpp_tests.yaml b/.github/workflows/cpp_tests.yaml index e08d90754..9ea4f0591 100644 --- a/.github/workflows/cpp_tests.yaml +++ b/.github/workflows/cpp_tests.yaml @@ -22,7 +22,7 @@ jobs: ffmpeg-version-for-tests: ['4.4.2', '5.1.2', '6.1.1', '7.0.1'] steps: - name: Check out repo - uses: actions/checkout@v3 + uses: actions/checkout@v6 - name: Setup conda env uses: conda-incubator/setup-miniconda@v3 with: @@ -37,8 +37,7 @@ jobs: - name: Update pip run: python -m pip install --upgrade pip - name: Install torch dependencies - run: | - python -m pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu + run: bash packaging/install_pytorch.sh cpu "torch" - name: Install ffmpeg, pkg-config and pybind11 run: | conda install "ffmpeg=${{ 
matrix.ffmpeg-version-for-tests }}" pkg-config pybind11 -c conda-forge diff --git a/.github/workflows/docs.yaml b/.github/workflows/docs.yaml deleted file mode 100644 index 0829246e0..000000000 --- a/.github/workflows/docs.yaml +++ /dev/null @@ -1,116 +0,0 @@ -name: Docs - -on: - push: - branches: [ main ] - pull_request: - -permissions: - id-token: write - contents: write - -defaults: - run: - shell: bash -l -eo pipefail {0} - -jobs: - generate-matrix: - uses: pytorch/test-infra/.github/workflows/generate_binary_build_matrix.yml@main - with: - package-type: wheel - os: linux - test-infra-repository: pytorch/test-infra - test-infra-ref: main - with-cpu: disable - with-xpu: disable - with-rocm: disable - with-cuda: enable - build-python-only: "disable" - build: - needs: generate-matrix - strategy: - fail-fast: false - name: Build and Upload wheel - uses: pytorch/test-infra/.github/workflows/build_wheels_linux.yml@main - with: - repository: meta-pytorch/torchcodec - ref: "" - test-infra-repository: pytorch/test-infra - test-infra-ref: main - build-matrix: ${{ needs.generate-matrix.outputs.matrix }} - pre-script: packaging/pre_build_script.sh - post-script: packaging/post_build_script.sh - smoke-test-script: packaging/fake_smoke_test.py - package-name: torchcodec - trigger-event: ${{ github.event_name }} - build-platform: "python-build-package" - build-command: "BUILD_AGAINST_ALL_FFMPEG_FROM_S3=1 ENABLE_CUDA=1 python -m build --wheel -vvv --no-isolation" - - build-docs: - runs-on: linux.4xlarge.nvidia.gpu - strategy: - fail-fast: false - matrix: - # 3.10 corresponds to the minimum python version for which we build - # the wheel unless the label cliflow/binaries/all is present in the - # PR. 
- python-version: ['3.10'] - cuda-version: ['12.6'] - ffmpeg-version-for-tests: ['7'] - container: - image: "pytorch/manylinux2_28-builder:cuda${{ matrix.cuda-version }}" - options: "--gpus all -e NVIDIA_DRIVER_CAPABILITIES=video,compute,utility" - needs: build - steps: - - name: Setup env vars - run: | - cuda_version_without_periods=$(echo "${{ matrix.cuda-version }}" | sed 's/\.//g') - echo cuda_version_without_periods=${cuda_version_without_periods} >> $GITHUB_ENV - python_version_without_periods=$(echo "${{ matrix.python-version }}" | sed 's/\.//g') - echo python_version_without_periods=${python_version_without_periods} >> $GITHUB_ENV - - uses: actions/download-artifact@v4 - with: - name: meta-pytorch_torchcodec__${{ matrix.python-version }}_cu${{ env.cuda_version_without_periods }}_x86_64 - path: pytorch/torchcodec/dist/ - - name: Setup miniconda using test-infra - uses: pytorch/test-infra/.github/actions/setup-miniconda@main - with: - python-version: ${{ matrix.python-version }} - # We install conda packages at the start because otherwise conda may have conflicts with dependencies. 
- default-packages: "nvidia/label/cuda-${{ matrix.cuda-version }}.0::libnpp nvidia::cuda-nvrtc=${{ matrix.cuda-version }} nvidia::cuda-toolkit=${{ matrix.cuda-version }} nvidia::cuda-cudart=${{ matrix.cuda-version }} nvidia::cuda-driver-dev=${{ matrix.cuda-version }} conda-forge::ffmpeg=${{ matrix.ffmpeg-version-for-tests }}" - - name: Check env - run: | - ${CONDA_RUN} env - ${CONDA_RUN} conda info - ${CONDA_RUN} nvidia-smi - ${CONDA_RUN} conda list - - name: Assert ffmpeg exists - run: | - ${CONDA_RUN} ffmpeg -buildconf - - name: Update pip - run: ${CONDA_RUN} python -m pip install --upgrade pip - - name: Install PyTorch - run: | - ${CONDA_RUN} python -m pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu${{ env.cuda_version_without_periods }} - ${CONDA_RUN} python -c 'import torch; print(f"{torch.__version__}"); print(f"{torch.__file__}"); print(f"{torch.cuda.is_available()=}")' - - name: Install torchcodec from the wheel - run: | - wheel_path=`find pytorch/torchcodec/dist -type f -name "*cu${{ env.cuda_version_without_periods }}-cp${{ env.python_version_without_periods }}*.whl"` - echo Installing $wheel_path - ${CONDA_RUN} python -m pip install $wheel_path -vvv - - - name: Check out repo - uses: actions/checkout@v3 - - - name: Install doc dependencies - run: | - cd docs - ${CONDA_RUN} python -m pip install -r requirements.txt - - name: Build docs - run: | - cd docs - ${CONDA_RUN} make html - - uses: actions/upload-artifact@v4 - with: - name: Built-Docs - path: docs/build/html/ diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml index c156a833c..84dc126f5 100644 --- a/.github/workflows/lint.yaml +++ b/.github/workflows/lint.yaml @@ -22,12 +22,12 @@ jobs: python-version: ['3.12'] steps: - name: Check out repo - uses: actions/checkout@v3 + uses: actions/checkout@v6 - name: Setup conda env - uses: conda-incubator/setup-miniconda@v2 + uses: conda-incubator/setup-miniconda@v3 with: auto-update-conda: true - 
miniconda-version: "latest" + miniforge-version: latest activate-environment: test python-version: ${{ matrix.python-version }} - name: Update pip @@ -50,19 +50,19 @@ jobs: python-version: ['3.12'] steps: - name: Check out repo - uses: actions/checkout@v3 + uses: actions/checkout@v6 - name: Setup conda env - uses: conda-incubator/setup-miniconda@v2 + uses: conda-incubator/setup-miniconda@v3 with: auto-update-conda: true - miniconda-version: "latest" + miniforge-version: latest activate-environment: test python-version: ${{ matrix.python-version }} - name: Update pip run: python -m pip install --upgrade pip - name: Install dependencies and FFmpeg run: | - python -m pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu + bash packaging/install_pytorch.sh cpu "torch torchvision" conda install "ffmpeg=7.0.1" pkg-config pybind11 -c conda-forge ffmpeg -version - name: Build and install torchcodec diff --git a/.github/workflows/linux_cuda_wheel.yaml b/.github/workflows/linux_cuda_wheel.yaml index 17f18fe8b..d8fd5ade9 100644 --- a/.github/workflows/linux_cuda_wheel.yaml +++ b/.github/workflows/linux_cuda_wheel.yaml @@ -1,4 +1,4 @@ -name: Build and test Linux CUDA wheels +name: Build and test Linux CUDA wheels and docs on: pull_request: @@ -84,10 +84,13 @@ jobs: echo cuda_version_without_periods=${cuda_version_without_periods} >> $GITHUB_ENV python_version_without_periods=$(echo "${{ matrix.python-version }}" | sed 's/\.//g') echo python_version_without_periods=${python_version_without_periods} >> $GITHUB_ENV - - uses: actions/download-artifact@v4 - with: - name: meta-pytorch_torchcodec__${{ matrix.python-version }}_cu${{ env.cuda_version_without_periods }}_x86_64 - path: pytorch/torchcodec/dist/ + + - name: Check out repo + uses: actions/checkout@v6 + + - name: Remove src/ folder + run: bash packaging/remove_src.sh + - name: Setup miniconda using test-infra uses: pytorch/test-infra/.github/actions/setup-miniconda@main with: @@ -95,12 +98,13 @@ 
jobs: # We install conda packages at the start because otherwise conda may have conflicts with dependencies. # Note: xorg-libxau was addded to fix a problem with ffmpeg 4. We should consider removing it. default-packages: "nvidia/label/cuda-${{ matrix.cuda-version }}.0::libnpp nvidia::cuda-nvrtc=${{ matrix.cuda-version }} nvidia::cuda-toolkit=${{ matrix.cuda-version }} nvidia::cuda-cudart=${{ matrix.cuda-version }} nvidia::cuda-driver-dev=${{ matrix.cuda-version }} conda-forge::ffmpeg=${{ matrix.ffmpeg-version-for-tests }} conda-forge::xorg-libxau" - - name: Check env + - name: Check env, set LD_LIBRARY_PATH run: | ${CONDA_RUN} env ${CONDA_RUN} conda info ${CONDA_RUN} nvidia-smi ${CONDA_RUN} conda list + echo LD_LIBRARY_PATH=$CONDA_PREFIX/lib:/usr/local/cuda/lib64/:${LD_LIBRARY_PATH} >> $GITHUB_ENV - name: Assert ffmpeg exists run: | ${CONDA_RUN} ffmpeg -buildconf @@ -108,39 +112,168 @@ jobs: run: ${CONDA_RUN} python -m pip install --upgrade pip - name: Install PyTorch run: | - ${CONDA_RUN} python -m pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu${{ env.cuda_version_without_periods }} + ${CONDA_RUN} bash packaging/install_pytorch.sh cu${{ env.cuda_version_without_periods }} "torch torchvision" ${CONDA_RUN} python -c 'import torch; print(f"{torch.__version__}"); print(f"{torch.__file__}"); print(f"{torch.cuda.is_available()=}")' + + - uses: actions/download-artifact@v4 + with: + name: meta-pytorch_torchcodec__${{ matrix.python-version }}_cu${{ env.cuda_version_without_periods }}_x86_64 + path: dist/ + - name: Install torchcodec from the wheel + run: ${CONDA_RUN} bash packaging/install_torchcodec_wheel.sh "*cu${{ env.cuda_version_without_periods }}-cp${{ env.python_version_without_periods }}*.whl" + + - name: Install test dependencies + run: ${CONDA_RUN} bash packaging/install_test_dependencies.sh + - name: Run Python tests run: | - wheel_path=`find pytorch/torchcodec/dist -type f -name "*cu${{ 
env.cuda_version_without_periods }}-cp${{ env.python_version_without_periods }}*.whl"` - echo Installing $wheel_path - ${CONDA_RUN} python -m pip install $wheel_path -vvv + ${CONDA_RUN} FAIL_WITHOUT_CUDA=1 pytest --override-ini="addopts=-v" test --tb=short + - name: Run Python benchmark + run: | + ${CONDA_RUN} time python benchmarks/decoders/gpu_benchmark.py --devices=cuda:0,cpu --resize_devices=none + + build-docs: + runs-on: linux.g5.4xlarge.nvidia.gpu + env: + PYTHON_VERSION: '3.10' + CUDA_VERSION: '12.6' + FFMPEG_VERSION: '7' + container: + image: "pytorch/manylinux2_28-builder:cuda12.6" # must be same as env!! + options: "--gpus all -e NVIDIA_DRIVER_CAPABILITIES=video,compute,utility" + needs: build + steps: + - name: Setup env vars + run: | + cuda_version_without_periods=$(echo "${{ env.CUDA_VERSION }}" | sed 's/\.//g') + echo cuda_version_without_periods=${cuda_version_without_periods} >> $GITHUB_ENV + python_version_without_periods=$(echo "${{ env.PYTHON_VERSION }}" | sed 's/\.//g') + echo python_version_without_periods=${python_version_without_periods} >> $GITHUB_ENV - name: Check out repo - uses: actions/checkout@v3 + uses: actions/checkout@v6 - - name: Install test dependencies + - name: Remove src/ folder + run: bash packaging/remove_src.sh + + - name: Setup miniconda using test-infra + uses: pytorch/test-infra/.github/actions/setup-miniconda@main + with: + python-version: ${{ env.PYTHON_VERSION }} + # We install conda packages at the start because otherwise conda may have conflicts with dependencies. 
+ default-packages: "nvidia/label/cuda-${{ env.CUDA_VERSION }}.0::libnpp nvidia::cuda-nvrtc=${{ env.CUDA_VERSION }} nvidia::cuda-toolkit=${{ env.CUDA_VERSION }} nvidia::cuda-cudart=${{ env.CUDA_VERSION }} nvidia::cuda-driver-dev=${{ env.CUDA_VERSION }} conda-forge::ffmpeg=${{ env.FFMPEG_VERSION }}" + - name: Check env, set LD_LIBRARY_PATH run: | - # Ideally we would find a way to get those dependencies from pyproject.toml - ${CONDA_RUN} python -m pip install numpy pytest pillow + ${CONDA_RUN} env + ${CONDA_RUN} conda info + ${CONDA_RUN} nvidia-smi + ${CONDA_RUN} conda list + echo LD_LIBRARY_PATH=$CONDA_PREFIX/lib:/usr/local/cuda/lib64/:${LD_LIBRARY_PATH} >> $GITHUB_ENV + - name: Assert ffmpeg exists + run: | + ${CONDA_RUN} ffmpeg -buildconf + - name: Update pip + run: ${CONDA_RUN} python -m pip install --upgrade pip + - name: Install PyTorch + run: | + ${CONDA_RUN} bash packaging/install_pytorch.sh cu${{ env.cuda_version_without_periods }} "torch torchvision" + ${CONDA_RUN} python -c 'import torch; print(f"{torch.__version__}"); print(f"{torch.__file__}"); print(f"{torch.cuda.is_available()=}")' - - name: Delete the src/ folder just for fun + - uses: actions/download-artifact@v4 + with: + name: meta-pytorch_torchcodec__${{ env.PYTHON_VERSION }}_cu${{ env.cuda_version_without_periods }}_x86_64 + path: dist/ + + - name: Install torchcodec from the wheel + run: ${CONDA_RUN} bash packaging/install_torchcodec_wheel.sh "*cu${{ env.cuda_version_without_periods }}-cp${{ env.python_version_without_periods }}*.whl" + + - name: Install doc dependencies run: | - # The only reason we checked-out the repo is to get access to the - # tests. We don't care about the rest. Out of precaution, we delete - # the src/ folder to be extra sure that we're running the code from - # the installed wheel rather than from the source. 
- # This is just to be extra cautious and very overkill because a) - # there's no way the `torchcodec` package from src/ can be found from - # the PythonPath: the main point of `src/` is precisely to protect - # against that and b) if we ever were to execute code from - # `src/torchcodec`, it would fail loudly because the built .so files - # aren't present there. - rm -r src/ - ls - - name: Run Python tests + cd docs + ${CONDA_RUN} python -m pip install -r requirements.txt + - name: Build docs run: | - ${CONDA_RUN} FAIL_WITHOUT_CUDA=1 pytest --override-ini="addopts=-v" test --tb=short - - name: Run Python benchmark + cd docs + ${CONDA_RUN} make html + - uses: actions/upload-artifact@v4 + with: + name: Built-Docs + path: docs/build/html/ + + doc-preview: + runs-on: [self-hosted, linux.2xlarge] + needs: build-docs + if: github.repository == 'meta-pytorch/torchcodec' && github.event_name == 'pull_request' + steps: + - uses: actions/download-artifact@v4 + with: + name: Built-Docs + path: docs + + # Update HTML to add the no-index tag so that search engines do not index these ephemeral docs + - name: Add no-index tag run: | - ${CONDA_RUN} time python benchmarks/decoders/gpu_benchmark.py --devices=cuda:0,cpu --resize_devices=none + find docs -name "*.html" -print0 | xargs -0 sed -i '/<head>/a \ \ <meta name="robots" content="noindex">'; + + - name: Upload docs preview + uses: seemethere/upload-artifact-s3@v5 + with: + retention-days: 14 + s3-bucket: doc-previews + if-no-files-found: error + path: docs + s3-prefix: meta-pytorch/torchcodec/${{ github.event.pull_request.number }} + + upload-docs: + # This job uploads built docs: + # - to the `main` folder of the gh-pages branch (https://meta-pytorch.org/torchcodec/main) on every commit to the `main` branch + # - to the (e.g.) `0.10` folder in the gh-pages branch whenever a corresponding tag is pushed, like `v0.10.0` (https://meta-pytorch.org/torchcodec/0.10). 
+ + needs: build-docs + if: github.repository == 'meta-pytorch/torchcodec' && github.event_name == 'push' && + ((github.ref_type == 'branch' && github.ref_name == 'main') || github.ref_type == 'tag') + runs-on: ubuntu-latest + steps: + - name: Check out gh-pages branch + uses: actions/checkout@v6 + with: + ref: gh-pages + + - uses: actions/download-artifact@v4 + with: + name: Built-Docs + path: docs-artifact/ + + - name: Update docs and push + run: | + set -euo pipefail + + REF_TYPE=${{ github.ref_type }} + REF_NAME=${{ github.ref_name }} + + if [[ "${REF_TYPE}" == branch ]]; then + TARGET_FOLDER="${REF_NAME}" + elif [[ "${REF_TYPE}" == tag ]]; then + case "${REF_NAME}" in + *-rc*) + echo "Aborting upload since this is an RC tag: ${REF_NAME}" + exit 0 + ;; + *) + # Strip the leading "v" as well as the trailing patch version. For example: + # 'v0.10.2' -> '0.10' + TARGET_FOLDER=$(echo "${REF_NAME}" | sed 's/v\([0-9]\+\)\.\([0-9]\+\)\.[0-9]\+/\1.\2/') + ;; + esac + fi + echo "Target Folder: ${TARGET_FOLDER}" + + mkdir -p "${TARGET_FOLDER}" + rm -rf "${TARGET_FOLDER}"/* + cp -r docs-artifact/* "${TARGET_FOLDER}/" + git add "${TARGET_FOLDER}" + + git config user.name 'pytorchbot' + git config user.email 'soumith+bot@pytorch.org' + git commit -m "auto-generating sphinx docs for ${TARGET_FOLDER}" || echo "No changes to commit" + git push diff --git a/.github/workflows/linux_wheel.yaml b/.github/workflows/linux_wheel.yaml index cccbedc25..14d702a91 100644 --- a/.github/workflows/linux_wheel.yaml +++ b/.github/workflows/linux_wheel.yaml @@ -66,10 +66,12 @@ jobs: ffmpeg-version-for-tests: ['4.4.2', '5.1.2', '6.1.1', '7.0.1', '8.0'] needs: build steps: - - uses: actions/download-artifact@v4 - with: - name: meta-pytorch_torchcodec__${{ matrix.python-version }}_cpu_x86_64 - path: pytorch/torchcodec/dist/ + - name: Check out repo + uses: actions/checkout@v6 + + - name: Remove src/ folder + run: bash packaging/remove_src.sh + - name: Setup conda env uses: 
conda-incubator/setup-miniconda@v3 with: @@ -81,50 +83,82 @@ jobs: miniforge-version: latest activate-environment: test python-version: ${{ matrix.python-version }} + - name: Update pip run: python -m pip install --upgrade pip + - name: Install PyTorch - run: | - python -m pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu + run: bash packaging/install_pytorch.sh cpu "torch torchvision" + + - uses: actions/download-artifact@v4 + with: + name: meta-pytorch_torchcodec__${{ matrix.python-version }}_cpu_x86_64 + path: dist/ + - name: Install torchcodec from the wheel + run: bash packaging/install_torchcodec_wheel.sh + + - name: Install ffmpeg, post build + run: bash packaging/install_ffmpeg.sh ${{ matrix.ffmpeg-version-for-tests }} + + - name: Install test dependencies + run: bash packaging/install_test_dependencies.sh + + - name: Run Python tests run: | - wheel_path=`find pytorch/torchcodec/dist -type f -name "*.whl"` - echo Installing $wheel_path - python -m pip install $wheel_path -vvv + pytest --override-ini="addopts=-v" test + install-and-test-third-party-interface: + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: ['3.10'] + ffmpeg-version-for-tests: ['8.0'] + needs: build + steps: - name: Check out repo uses: actions/checkout@v3 + + - name: Remove src/ folder + run: bash packaging/remove_src.sh + + - name: Setup conda env + uses: conda-incubator/setup-miniconda@v3 + with: + auto-update-conda: true + # Using miniforge instead of miniconda ensures that the default + # conda channel is conda-forge instead of main/default. 
This ensures + # ABI consistency between dependencies: + # https://conda-forge.org/docs/user/transitioning_from_defaults/ + miniforge-version: latest + activate-environment: test + python-version: ${{ matrix.python-version }} + + - name: Update pip + run: python -m pip install --upgrade pip + + - name: Install PyTorch + run: bash packaging/install_pytorch.sh cpu "torch torchvision" + + - uses: actions/download-artifact@v4 + with: + name: meta-pytorch_torchcodec__${{ matrix.python-version }}_cpu_x86_64 + path: dist/ + + - name: Install torchcodec from the wheel + run: bash packaging/install_torchcodec_wheel.sh + - name: Install ffmpeg, post build - run: | - # Ideally we would have checked for that before installing the wheel, - # but we need to checkout the repo to access this file, and we don't - # want to checkout the repo before installing the wheel to avoid any - # side-effect. It's OK. - source packaging/helpers.sh - assert_ffmpeg_not_installed + run: bash packaging/install_ffmpeg.sh ${{ matrix.ffmpeg-version-for-tests }} - conda install "ffmpeg=${{ matrix.ffmpeg-version-for-tests }}" -c conda-forge - ffmpeg -version + - name: Install pkg-config + run: | + conda install pkg-config -c conda-forge - name: Install test dependencies - run: | - # Ideally we would find a way to get those dependencies from pyproject.toml - python -m pip install numpy pytest pillow + run: bash packaging/install_test_dependencies.sh - - name: Delete the src/ folder just for fun - run: | - # The only reason we checked-out the repo is to get access to the - # tests. We don't care about the rest. Out of precaution, we delete - # the src/ folder to be extra sure that we're running the code from - # the installed wheel rather than from the source. 
- # This is just to be extra cautious and very overkill because a) - # there's no way the `torchcodec` package from src/ can be found from - # the PythonPath: the main point of `src/` is precisely to protect - # against that and b) if we ever were to execute code from - # `src/torchcodec`, it would fail loudly because the built .so files - # aren't present there. - rm -r src/ - ls - name: Run Python tests run: | - pytest --override-ini="addopts=-v" test + pytest --override-ini="addopts=-v" test/third-party-interface diff --git a/.github/workflows/macos_wheel.yaml b/.github/workflows/macos_wheel.yaml index ead45784d..b183b80cd 100644 --- a/.github/workflows/macos_wheel.yaml +++ b/.github/workflows/macos_wheel.yaml @@ -68,58 +68,43 @@ jobs: ffmpeg-version-for-tests: ['4.4.2', '5.1.2', '6.1.1', '7.0.1', '8.0'] needs: build steps: - - name: Download wheel - uses: actions/download-artifact@v4 - with: - name: meta-pytorch_torchcodec__${{ matrix.python-version }}_cpu_ - path: pytorch/torchcodec/dist/ + - name: Check out torchcodec repo + uses: actions/checkout@v6 + + - name: Remove src/ folder + run: bash packaging/remove_src.sh - name: Setup conda env uses: conda-incubator/setup-miniconda@v3 with: auto-update-conda: true - miniconda-version: "latest" + # Using miniforge instead of miniconda ensures that the default + # conda channel is conda-forge instead of main/default. 
This ensures + # ABI consistency between dependencies: + # https://conda-forge.org/docs/user/transitioning_from_defaults/ + miniforge-version: latest activate-environment: test python-version: ${{ matrix.python-version }} - name: Update pip run: python -m pip install --upgrade pip - name: Install PyTorch - run: | - python -m pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu + run: bash packaging/install_pytorch.sh cpu "torch torchvision" - - name: Install torchcodec from the wheel - run: | - wheel_path=`find pytorch/torchcodec/dist -type f -name "*.whl"` - echo Installing $wheel_path - python -m pip install $wheel_path -vvv + - name: Download wheel + uses: actions/download-artifact@v4 + with: + name: meta-pytorch_torchcodec__${{ matrix.python-version }}_cpu_ + path: dist/ - - name: Check out torchcodec repo - uses: actions/checkout@v3 + - name: Install torchcodec from the wheel + run: bash packaging/install_torchcodec_wheel.sh - name: Install ffmpeg - run: | - conda install "ffmpeg=${{ matrix.ffmpeg-version-for-tests }}" -c conda-forge - ffmpeg -version + run: bash packaging/install_ffmpeg.sh ${{ matrix.ffmpeg-version-for-tests }} - name: Install test dependencies - run: | - python -m pip install numpy pytest pillow - - - name: Delete the src/ folder just for fun - run: | - # The only reason we checked-out the repo is to get access to the - # tests. We don't care about the rest. Out of precaution, we delete - # the src/ folder to be extra sure that we're running the code from - # the installed wheel rather than from the source. - # This is just to be extra cautious and very overkill because a) - # there's no way the `torchcodec` package from src/ can be found from - # the PythonPath: the main point of `src/` is precisely to protect - # against that and b) if we ever were to execute code from - # `src/torchcodec`, it would fail loudly because the built .so files - # aren't present there. 
- rm -r src/ - ls -lh + run: bash packaging/install_test_dependencies.sh - name: Run Python tests run: | diff --git a/.github/workflows/paddle_wheel.yaml b/.github/workflows/paddle_wheel.yaml index 4976e20c9..3ce2964c6 100644 --- a/.github/workflows/paddle_wheel.yaml +++ b/.github/workflows/paddle_wheel.yaml @@ -20,15 +20,61 @@ defaults: run: shell: bash -l -eo pipefail {0} +env: + PADDLECODEC_TEST_VIDEO_URL: https://paddlenlp.bj.bcebos.com/datasets/paddlemix/demo_video/example_video.mp4 + PADDLECODEC_TEST_VIDEO_CACHE_KEY: paddlecodec-test-video-v1-example-video + PADDLECODEC_TEST_VIDEO_PATH: .github/test-assets/example_video.mp4 + jobs: - build-paddlecodec-wheel: + prepare-test-video: runs-on: ubuntu-latest + name: Prepare cached Paddle test video + steps: + - name: Restore cached test video + id: cache-test-video + uses: actions/cache@v4 + with: + path: ${{ env.PADDLECODEC_TEST_VIDEO_PATH }} + key: ${{ env.PADDLECODEC_TEST_VIDEO_CACHE_KEY }} + + - name: Download test video + if: steps.cache-test-video.outputs.cache-hit != 'true' + run: | + mkdir -p "$(dirname "${PADDLECODEC_TEST_VIDEO_PATH}")" + curl --fail --location --retry 5 --retry-all-errors \ + --output "${PADDLECODEC_TEST_VIDEO_PATH}" \ + "${PADDLECODEC_TEST_VIDEO_URL}" + + - name: Upload cached test video artifact + uses: actions/upload-artifact@v5 + with: + name: paddlecodec-test-video + path: ${{ env.PADDLECODEC_TEST_VIDEO_PATH }} + if-no-files-found: error + + build-paddlecodec-wheel: + name: Build and upload Paddle wheel (${{ matrix.arch-name }}, py${{ matrix.python-version }}) + runs-on: ${{ matrix.runner }} container: - image: pytorch/manylinux2_28-builder:cpu + image: ${{ matrix.container-image }} strategy: fail-fast: false matrix: python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] + arch: ["x86_64", "arm64"] + include: + - arch: x86_64 + arch-name: x86_64 + runner: ubuntu-latest + container-image: pytorch/manylinux2_28-builder:cpu + artifact-prefix: paddlecodec-wheel-linux + wheel-platform: 
manylinux_2_28_x86_64 + - arch: arm64 + arch-name: arm64 + runner: ubuntu-24.04-arm + container-image: pytorch/manylinux2_28_aarch64-builder:cpu-aarch64 + artifact-prefix: paddlecodec-wheel-linux-arm64 + wheel-platform: manylinux_2_28_aarch64 permissions: id-token: write contents: read @@ -59,8 +105,8 @@ jobs: - name: Build wheel run: | - # Use pre-built FFmpeg from PyTorch S3 export BUILD_AGAINST_ALL_FFMPEG_FROM_S3=1 + export I_CONFIRM_THIS_IS_NOT_A_LICENSE_VIOLATION=1 export TORCHCODEC_CMAKE_BUILD_DIR=$(pwd)/build_cmake python -m build --wheel -vvv --no-isolation @@ -68,29 +114,18 @@ jobs: run: | pip install auditwheel - # 1. Extract internal libraries from the wheel to a temporary directory - # This allows auditwheel to find them when checking dependencies mkdir -p temp_libs unzip -j dist/*.whl "torchcodec/*.so" -d temp_libs || true - # 2. Prepare LD_LIBRARY_PATH - # FFmpeg libraries - FFMPEG_LIB_PATHS=$(find $(pwd)/build_cmake/_deps -type d -name "lib" | tr '\n' ':') - # PaddlePaddle libraries + FFMPEG_LIB_PATHS=$(find "$(pwd)/build_cmake/_deps" -type d -name "lib" -print | paste -sd: -) PADDLE_PATH=$(python -c "import paddle; print(paddle.__path__[0])") PADDLE_LIB_PATHS="$PADDLE_PATH/base:$PADDLE_PATH/libs" - # Wheel internal libraries INTERNAL_LIB_PATH=$(pwd)/temp_libs - export LD_LIBRARY_PATH=${FFMPEG_LIB_PATHS}${PADDLE_LIB_PATHS}:${INTERNAL_LIB_PATH}:${LD_LIBRARY_PATH} + export LD_LIBRARY_PATH=${FFMPEG_LIB_PATHS}:${PADDLE_LIB_PATHS}:${INTERNAL_LIB_PATH}:${LD_LIBRARY_PATH} - # 3. Repair wheel with auditwheel - # We exclude all external libraries because we want to rely on system libraries (like FFmpeg) - # or libraries provided by other packages (like PaddlePaddle). - # auditwheel 6.1.0+ supports wildcards in --exclude. 
- auditwheel repair dist/*.whl --plat manylinux_2_28_x86_64 -w wheelhouse/ --exclude "*" + auditwheel repair dist/*.whl --plat ${{ matrix.wheel-platform }} -w wheelhouse/ --exclude "*" - # Cleanup rm -rf temp_libs rm dist/*.whl mv wheelhouse/*.whl dist/ @@ -99,7 +134,7 @@ jobs: - name: Upload wheel artifact uses: actions/upload-artifact@v5 with: - name: paddlecodec-wheel-linux-py${{ matrix.python-version }} + name: ${{ matrix.artifact-prefix }}-py${{ matrix.python-version }} path: dist/*.whl - name: Run post-build script @@ -113,31 +148,49 @@ jobs: unzip -l $wheel_path test-paddlecodec-wheel: - needs: build-paddlecodec-wheel - runs-on: ubuntu-latest + name: Install and test Paddle wheel (${{ matrix.arch-name }}, py${{ matrix.python-version }}, ffmpeg ${{ matrix.ffmpeg-version }}) + needs: [prepare-test-video, build-paddlecodec-wheel] + runs-on: ${{ matrix.runner }} + container: + image: ${{ matrix.container-image }} + env: + PADDLECODEC_TEST_VIDEO: .github/test-assets/example_video.mp4 strategy: fail-fast: false matrix: python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] + arch: ["x86_64", "arm64"] # FFmpeg 8.0 depends on libopenvino.so.2520, PaddlePaddle CPU depends on libopenvino.so.2500 # There has some conflict causing test failures, but it works with PaddlePaddle GPU. # We skip FFmpeg 8.0 tests for PaddlePaddle CPU builds for now. 
ffmpeg-version: ["4.4.2", "5.1.2", "6.1.1", "7.0.1"] + include: + - arch: x86_64 + arch-name: x86_64 + runner: ubuntu-latest + container-image: pytorch/manylinux2_28-builder:cpu + artifact-prefix: paddlecodec-wheel-linux + - arch: arm64 + arch-name: arm64 + runner: ubuntu-24.04-arm + container-image: pytorch/manylinux2_28_aarch64-builder:cpu-aarch64 + artifact-prefix: paddlecodec-wheel-linux-arm64 steps: - name: Checkout repository uses: actions/checkout@v6 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v5 - with: - python-version: ${{ matrix.python-version }} - - name: Download wheel artifact uses: actions/download-artifact@v4 with: - name: paddlecodec-wheel-linux-py${{ matrix.python-version }} + name: ${{ matrix.artifact-prefix }}-py${{ matrix.python-version }} path: dist/ + - name: Download cached test video artifact + uses: actions/download-artifact@v4 + with: + name: paddlecodec-test-video + path: .github/test-assets/ + - name: Install FFmpeg via conda uses: conda-incubator/setup-miniconda@v3 with: @@ -167,7 +220,6 @@ jobs: - name: Delete src folder run: | - # Delete src/ to ensure we're testing the installed wheel, not source code rm -rf src/ ls -la @@ -177,18 +229,17 @@ jobs: publish-pypi: runs-on: ubuntu-latest - name: Publish to PyPI + name: Publish Paddle wheels to PyPI if: "startsWith(github.ref, 'refs/tags/')" needs: - test-paddlecodec-wheel permissions: id-token: write - steps: - name: Retrieve release distributions uses: actions/download-artifact@v6 with: - pattern: paddlecodec-wheel-linux-* + pattern: paddlecodec-wheel-linux* path: dist/ merge-multiple: true @@ -197,7 +248,7 @@ jobs: publish-release: runs-on: ubuntu-latest - name: Publish to GitHub + name: Publish Paddle wheels to GitHub if: "startsWith(github.ref, 'refs/tags/')" needs: - test-paddlecodec-wheel @@ -206,7 +257,7 @@ jobs: steps: - uses: actions/download-artifact@v6 with: - pattern: paddlecodec-wheel-linux-* + pattern: paddlecodec-wheel-linux* path: 
dist/ merge-multiple: true - name: Get tag name diff --git a/.github/workflows/reference_resources.yaml b/.github/workflows/reference_resources.yaml index 8f97378f1..135799731 100644 --- a/.github/workflows/reference_resources.yaml +++ b/.github/workflows/reference_resources.yaml @@ -53,42 +53,41 @@ jobs: fail-fast: false matrix: python-version: ['3.10'] - ffmpeg-version-for-tests: ['4.4.2', '5.1.2', '6.1.1', '7.0.1'] + # Traditionally we generate the resources locally on FFmpeg 4 or 6. + # The exact version shouldn't matter as long as the unit tests pass + # across all versions for a given generated resource. + ffmpeg-version-for-tests: ['6.1.1'] steps: - - uses: actions/download-artifact@v4 - with: - name: meta-pytorch_torchcodec__${{ matrix.python-version }}_cpu_x86_64 - path: pytorch/torchcodec/dist/ + - name: Check out repo + uses: actions/checkout@v3 + - name: Setup conda env - uses: conda-incubator/setup-miniconda@v2 + uses: conda-incubator/setup-miniconda@v3 with: auto-update-conda: true - miniconda-version: "latest" + miniforge-version: latest activate-environment: test python-version: ${{ matrix.python-version }} - - name: Install ffmpeg - run: | - conda install "ffmpeg=${{ matrix.ffmpeg-version-for-tests }}" -c conda-forge - ffmpeg -version - - name: Update pip run: python -m pip install --upgrade pip - - name: Install generation dependencies - run: | - # Note that we're installing stable - this is for running a script where we're a normal PyTorch - # user, not for building TorhCodec. 
- python -m pip install torch --index-url https://download.pytorch.org/whl/cpu - python -m pip install numpy pillow pytest + - name: Install PyTorch + run: bash packaging/install_pytorch.sh cpu "torch" + + - uses: actions/download-artifact@v4 + with: + name: meta-pytorch_torchcodec__${{ matrix.python-version }}_cpu_x86_64 + path: dist/ - name: Install torchcodec from the wheel - run: | - wheel_path=`find pytorch/torchcodec/dist -type f -name "*.whl"` - echo Installing $wheel_path - python -m pip install $wheel_path -vvv - - name: Check out repo - uses: actions/checkout@v3 + run: bash packaging/install_torchcodec_wheel.sh + + - name: Install ffmpeg, post build + run: bash packaging/install_ffmpeg.sh ${{ matrix.ffmpeg-version-for-tests }} + + - name: Install test dependencies + run: bash packaging/install_test_dependencies.sh - name: Run generation reference resources run: | diff --git a/.github/workflows/windows_wheel.yaml b/.github/workflows/windows_wheel.yaml index 8a9b5b740..72df7c9c0 100644 --- a/.github/workflows/windows_wheel.yaml +++ b/.github/workflows/windows_wheel.yaml @@ -74,12 +74,14 @@ jobs: ffmpeg-version-for-tests: ['4.4.2', '6.1.1', '7.0.1', '8.0'] needs: build steps: - - uses: actions/download-artifact@v4 - with: - name: meta-pytorch_torchcodec__${{ matrix.python-version }}_cpu_x64 - path: pytorch/torchcodec/dist/ + - name: Check out repo + uses: actions/checkout@v6 + + - name: Remove src/ folder + run: bash packaging/remove_src.sh + - name: Setup conda env - uses: conda-incubator/setup-miniconda@v2 + uses: conda-incubator/setup-miniconda@v3 with: auto-update-conda: true # Using miniforge instead of miniconda ensures that the default @@ -91,48 +93,24 @@ jobs: python-version: ${{ matrix.python-version }} - name: Update pip run: python -m pip install --upgrade pip + - name: Install PyTorch - run: | - python -m pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu + run: bash packaging/install_pytorch.sh cpu 
"torch torchvision" + + - uses: actions/download-artifact@v4 + with: + name: meta-pytorch_torchcodec__${{ matrix.python-version }}_cpu_x64 + path: dist/ + - name: Install torchcodec from the wheel - run: | - wheel_path=`find pytorch/torchcodec/dist -type f -name "*.whl"` - echo Installing $wheel_path - python -m pip install $wheel_path -vvv - - name: Check out repo - uses: actions/checkout@v3 + run: bash packaging/install_torchcodec_wheel.sh + - name: Install ffmpeg, post build - run: | - # Ideally we would have checked for that before installing the wheel, - # but we need to checkout the repo to access this file, and we don't - # want to checkout the repo before installing the wheel to avoid any - # side-effect. It's OK. - source packaging/helpers.sh - assert_ffmpeg_not_installed - conda install "ffmpeg=${{ matrix.ffmpeg-version-for-tests }}" -c conda-forge - ffmpeg -version - - name: Test torchcodec import after FFmpeg installation - run: | - echo "Testing torchcodec import after FFmpeg is installed and PATH is updated..." - python -c "import torchcodec; print('TorchCodec import successful!')" + # need -l for conda to be exposed + run: bash -l packaging/install_ffmpeg.sh ${{ matrix.ffmpeg-version-for-tests }} + - name: Install test dependencies - run: | - # Ideally we would find a way to get those dependencies from pyproject.toml - python -m pip install numpy pytest pillow - - name: Delete the src/ folder just for fun - run: | - # The only reason we checked-out the repo is to get access to the - # tests. We don't care about the rest. Out of precaution, we delete - # the src/ folder to be extra sure that we're running the code from - # the installed wheel rather than from the source. 
- # This is just to be extra cautious and very overkill because a) - # there's no way the `torchcodec` package from src/ can be found from - # the PythonPath: the main point of `src/` is precisely to protect - # against that and b) if we ever were to execute code from - # `src/torchcodec`, it would fail loudly because the built .so files - # aren't present there. - rm -r src/ - ls + run: bash packaging/install_test_dependencies.sh + - name: Run Python tests - run: | - pytest test -vvv + run: pytest --override-ini="addopts=-v" test diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 84e404a3e..775c100d0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -14,6 +14,14 @@ repos: - id: check-added-large-files args: ['--maxkb=1000'] + - repo: https://github.com/asottile/pyupgrade + rev: v3.21.2 + hooks: + - id: pyupgrade + args: [--py310-plus] + files: ^(test|src)/ + exclude: ^examples/ + - repo: https://github.com/omnilib/ufmt rev: v2.6.0 hooks: diff --git a/CMakeLists.txt b/CMakeLists.txt index 0f2d0de2d..18dc98c01 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,6 +1,13 @@ cmake_minimum_required(VERSION 3.18) project(TorchCodec) +# Define LINUX platform variable globally +if (UNIX AND NOT APPLE) + set(LINUX TRUE) +else() + set(LINUX FALSE) +endif() + add_subdirectory(src/torchcodec/_core) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 6c42e98f2..7a651d4e6 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -30,7 +30,7 @@ Start by installing the **nightly** build of PyTorch following the Then, the easiest way to install the rest of the dependencies is to run: ```bash -conda install cmake pkg-config pybind11 "ffmpeg<8" -c conda-forge +conda install cmake pkg-config pybind11 "ffmpeg" -c conda-forge ``` ### Clone and build @@ -114,6 +114,34 @@ all of them, you can use a regex like Run `make clean` from time to time if you encounter issues. 
+### Serving docs locally (if building from a GPU env) + +If you're developing locally, you can just open the generated `index.html`file +in your browser. + +If instead you're using a remote machine, you can use a combination of a simple +python HTTP server and port forwarding to serve the docs locally. This allows +you to iterate on the documentation much more quickly than relying on +PR previews. + +To do so, after following the above doc build steps, run the following from +the `docs/build/html` folder: + +``` +python -m http.server 8000 # or any free port +``` + +This will open up a simple HTTP server serving the files in the build directory. +If this is done on a remote machine, you can set up port forwarding from your +local machine to access the server, for example: + +``` +ssh -L 9000:localhost:8000 $REMOTE_DEV_HOST +``` + +Now, you can navigate to `localhost:9000` on your local machine to view the +rendered documentation. + ## License By contributing to TorchCodec, you agree that your contributions will be diff --git a/README.md b/README.md index 5c84184d2..73de0969f 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -[**Installation**](#installing-torchcodec) | [**Simple Example**](#using-torchcodec) | [**Detailed Example**](https://pytorch.org/torchcodec/stable/generated_examples/) | [**Documentation**](https://pytorch.org/torchcodec) | [**Contributing**](CONTRIBUTING.md) | [**License**](#license) +[**Installation**](#installing-torchcodec) | [**Simple Example**](#using-torchcodec) | [**Detailed Example**](https://meta-pytorch.org/torchcodec/stable/generated_examples/) | [**Documentation**](https://meta-pytorch.org/torchcodec) | [**Contributing**](CONTRIBUTING.md) | [**License**](#license) # PaddleCodec @@ -26,9 +26,9 @@ The original README.md content is as follows: --- TorchCodec is a Python library for decoding video and audio data into PyTorch -tensors, on CPU and CUDA GPU. It also supports audio encoding, and video -encoding will come soon! 
It aims to be fast, easy to use, and well integrated -into the PyTorch ecosystem. If you want to use PyTorch to train ML models on +tensors, on CPU and CUDA GPU. It also supports video and audio encoding on CPU! +It aims to be fast, easy to use, and well integrated +into the PyTorch ecosystem. If you want to use PyTorch to train ML models on videos and audio, TorchCodec is how you turn these into data. We achieve these capabilities through: @@ -46,7 +46,7 @@ We achieve these capabilities through: Here's a condensed summary of what you can do with TorchCodec. For more detailed examples, [check out our -documentation](https://pytorch.org/torchcodec/stable/generated_examples/)! +documentation](https://meta-pytorch.org/torchcodec/stable/generated_examples/)! #### Decoding @@ -130,40 +130,45 @@ ffmpeg -f lavfi -i \ versions, refer to the table below for compatibility between versions of `torch` and `torchcodec`. -2. Install FFmpeg, if it's not already installed. Linux distributions usually - come with FFmpeg pre-installed. TorchCodec supports major FFmpeg versions - in [4, 7] on all platforms, and FFmpeg version 8 is supported on Mac and Linux. +2. Install FFmpeg, if it's not already installed. TorchCodec supports + all major FFmpeg versions in [4, 8]. + Linux distributions usually come with FFmpeg pre-installed. You'll need + FFmpeg that comes with separate shared libraries. This is especially relevant + for Windows users: these are usually called the "shared" releases. If FFmpeg is not already installed, or you need a more recent version, an easy way to install it is to use `conda`: ```bash - conda install "ffmpeg<8" + conda install "ffmpeg" # or - conda install "ffmpeg<8" -c conda-forge + conda install "ffmpeg" -c conda-forge ``` 3. Install TorchCodec: ```bash - pip install torchcodec + pip install torchcodec --index-url=https://download.pytorch.org/whl/cpu ``` The following table indicates the compatibility between versions of `torchcodec`, `torch` and Python. 
-| `torchcodec` | `torch` | Python | -| ------------------ | ------------------ | ------------------ | -| `main` / `nightly` | `main` / `nightly` | `>=3.10`, `<=3.13` | -| `0.8` | `2.9` | `>=3.10`, `<=3.13` | -| `0.7` | `2.8` | `>=3.9`, `<=3.13` | -| `0.6` | `2.8` | `>=3.9`, `<=3.13` | -| `0.5` | `2.7` | `>=3.9`, `<=3.13` | -| `0.4` | `2.7` | `>=3.9`, `<=3.13` | -| `0.3` | `2.7` | `>=3.9`, `<=3.13` | -| `0.2` | `2.6` | `>=3.9`, `<=3.13` | -| `0.1` | `2.5` | `>=3.9`, `<=3.12` | -| `0.0.3` | `2.4` | `>=3.8`, `<=3.12` | +| `torchcodec` | `torch` | Python | +| ------------------ | ------------------ | ------------------- | +| `main` / `nightly` | `main` / `nightly` | `>=3.10`, `<=3.14` | +| `0.11` | `2.11` | `>=3.10`, `<=3.14` | +| `0.10` | `2.10` | `>=3.10`, `<=3.14` | +| `0.9` | `2.9` | `>=3.10`, `<=3.14` | +| `0.8` | `2.9` | `>=3.10`, `<=3.13` | +| `0.7` | `2.8` | `>=3.9`, `<=3.13` | +| `0.6` | `2.8` | `>=3.9`, `<=3.13` | +| `0.5` | `2.7` | `>=3.9`, `<=3.13` | +| `0.4` | `2.7` | `>=3.9`, `<=3.13` | +| `0.3` | `2.7` | `>=3.9`, `<=3.13` | +| `0.2` | `2.6` | `>=3.9`, `<=3.13` | +| `0.1` | `2.5` | `>=3.9`, `<=3.12` | +| `0.0.3` | `2.4` | `>=3.8`, `<=3.12` | ### Installing CUDA-enabled TorchCodec @@ -172,16 +177,15 @@ format you want. Refer to Nvidia's GPU support matrix for more details [here](https://developer.nvidia.com/video-encode-and-decode-gpu-support-matrix-new). 1. Install FFmpeg with NVDEC support. - TorchCodec with CUDA should work with FFmpeg versions in [4, 7] on all platforms, - and FFmpeg version 8 is supported on Linux. + TorchCodec with CUDA should work with FFmpeg versions in [4, 8]. 
If FFmpeg is not already installed, or you need a more recent version, an easy way to install it is to use `conda`: ```bash - conda install "ffmpeg<8" + conda install "ffmpeg" # or - conda install "ffmpeg<8" -c conda-forge + conda install "ffmpeg" -c conda-forge ``` After installing FFmpeg make sure it has NVDEC support when you list the supported @@ -208,17 +212,19 @@ format you want. Refer to Nvidia's GPU support matrix for more details 3. Install TorchCodec - Pass in an `--index-url` parameter that corresponds to your CUDA Toolkit - version, for example: + On Linux, `pip install torchcodec` defaults to a CUDA wheel, + matching the default behavior of `pip install torch`. ```bash - # This corresponds to CUDA Toolkit version 12.6. It should be the same one - # you used when you installed PyTorch (If you installed PyTorch with pip). - pip install torchcodec --index-url=https://download.pytorch.org/whl/cu126 + pip install torchcodec ``` + Use `--index-url` to select a different CUDA Toolkit version: - Note that without passing in the `--index-url` parameter, `pip` installs - the CPU-only version of TorchCodec. + ```bash + # This corresponds to CUDA Toolkit version 13.0. It should be the same one + # you used when you installed PyTorch (If you installed PyTorch with pip). + pip install torchcodec --index-url=https://download.pytorch.org/whl/cu130 + ``` #### Windows @@ -242,7 +248,7 @@ The bottom row is [promotional video from NASA](https://download.pytorch.org/tor that has a resolution of 960x540 at 29.7 fps and is 206 seconds long. Both videos were encoded with libx264 and yuv420p pixel format. All decoders, except for TorchVision, used FFmpeg 6.1.2. TorchVision used FFmpeg 4.2.2. 
-For TorchCodec, the "approx" label means that it was using [approximate mode](https://pytorch.org/torchcodec/stable/generated_examples/approximate_mode.html) +For TorchCodec, the "approx" label means that it was using [approximate mode](https://meta-pytorch.org/torchcodec/stable/generated_examples/decoding/approximate_mode.html) for seeking. ## Contributing diff --git a/benchmarks/decoders/benchmark_decoders_library.py b/benchmarks/decoders/benchmark_decoders_library.py index a975aec7e..57174ab89 100644 --- a/benchmarks/decoders/benchmark_decoders_library.py +++ b/benchmarks/decoders/benchmark_decoders_library.py @@ -14,14 +14,8 @@ import torch import torch.utils.benchmark as benchmark -from torchcodec._core import ( - _add_video_stream, - create_from_file, - get_frames_at_indices, - get_frames_by_pts, - get_next_frame, - seek_to_pts, -) +from torchcodec._core import get_frames_at_indices, get_frames_by_pts, get_next_frame +from torchcodec._core.ops import _add_video_stream, create_from_file, seek_to_pts from torchcodec._frame import FrameBatch from torchcodec.decoders import VideoDecoder, VideoStreamMetadata diff --git a/benchmarks/decoders/benchmark_transforms.py b/benchmarks/decoders/benchmark_transforms.py index 75a49d63b..01222f403 100644 --- a/benchmarks/decoders/benchmark_transforms.py +++ b/benchmarks/decoders/benchmark_transforms.py @@ -5,14 +5,11 @@ import torch from torch import Tensor -from torchcodec._core import add_video_stream, create_from_file, get_frames_by_pts from torchcodec.decoders import VideoDecoder from torchvision.transforms import v2 -DEFAULT_NUM_EXP = 20 - -def bench(f, *args, num_exp=DEFAULT_NUM_EXP, warmup=1) -> Tensor: +def bench(f, *args, num_exp, warmup=1) -> Tensor: for _ in range(warmup): f(*args) @@ -45,37 +42,55 @@ def report_stats(times: Tensor, unit: str = "ms", prefix: str = "") -> float: def torchvision_resize( - path: Path, pts_seconds: list[float], dims: tuple[int, int] -) -> None: - decoder = create_from_file(str(path), 
seek_mode="approximate") - add_video_stream(decoder) - raw_frames, *_ = get_frames_by_pts(decoder, timestamps=pts_seconds) - return v2.functional.resize(raw_frames, size=dims) + path: Path, pts_seconds: list[float], dims: tuple[int, int], num_threads: int +) -> Tensor: + decoder = VideoDecoder( + path, seek_mode="approximate", num_ffmpeg_threads=num_threads + ) + raw_frames = decoder.get_frames_played_at(pts_seconds) + transformed_frames = v2.Resize(size=dims)(raw_frames.data) + assert len(transformed_frames) == len(pts_seconds) + return transformed_frames def torchvision_crop( - path: Path, pts_seconds: list[float], dims: tuple[int, int], x: int, y: int -) -> None: - decoder = create_from_file(str(path), seek_mode="approximate") - add_video_stream(decoder) - raw_frames, *_ = get_frames_by_pts(decoder, timestamps=pts_seconds) - return v2.functional.crop(raw_frames, top=y, left=x, height=dims[0], width=dims[1]) - - -def decoder_native_resize( - path: Path, pts_seconds: list[float], dims: tuple[int, int] -) -> None: - decoder = create_from_file(str(path), seek_mode="approximate") - add_video_stream(decoder, transform_specs=f"resize, {dims[0]}, {dims[1]}") - return get_frames_by_pts(decoder, timestamps=pts_seconds)[0] - - -def decoder_native_crop( - path: Path, pts_seconds: list[float], dims: tuple[int, int], x: int, y: int -) -> None: - decoder = create_from_file(str(path), seek_mode="approximate") - add_video_stream(decoder, transform_specs=f"crop, {dims[0]}, {dims[1]}, {x}, {y}") - return get_frames_by_pts(decoder, timestamps=pts_seconds)[0] + path: Path, pts_seconds: list[float], dims: tuple[int, int], num_threads: int +) -> Tensor: + decoder = VideoDecoder( + path, seek_mode="approximate", num_ffmpeg_threads=num_threads + ) + raw_frames = decoder.get_frames_played_at(pts_seconds) + transformed_frames = v2.CenterCrop(size=dims)(raw_frames.data) + assert len(transformed_frames) == len(pts_seconds) + return transformed_frames + + +def decoder_resize( + path: Path, 
pts_seconds: list[float], dims: tuple[int, int], num_threads: int +) -> Tensor: + decoder = VideoDecoder( + path, + transforms=[v2.Resize(size=dims)], + seek_mode="approximate", + num_ffmpeg_threads=num_threads, + ) + transformed_frames = decoder.get_frames_played_at(pts_seconds).data + assert len(transformed_frames) == len(pts_seconds) + return transformed_frames.data + + +def decoder_crop( + path: Path, pts_seconds: list[float], dims: tuple[int, int], num_threads: int +) -> Tensor: + decoder = VideoDecoder( + path, + transforms=[v2.CenterCrop(size=dims)], + seek_mode="approximate", + num_ffmpeg_threads=num_threads, + ) + transformed_frames = decoder.get_frames_played_at(pts_seconds).data + assert len(transformed_frames) == len(pts_seconds) + return transformed_frames def main(): @@ -84,9 +99,27 @@ def main(): parser.add_argument( "--num-exp", type=int, - default=DEFAULT_NUM_EXP, + default=5, help="number of runs to average over", ) + parser.add_argument( + "--num-threads", + type=int, + default=1, + help="number of threads to use; 0 means FFmpeg decides", + ) + parser.add_argument( + "--total-frame-fractions", + nargs="+", + type=float, + default=[0.005, 0.01, 0.05, 0.1], + ) + parser.add_argument( + "--input-dimension-fractions", + nargs="+", + type=float, + default=[0.5, 0.25, 0.125], + ) args = parser.parse_args() path = Path(args.path) @@ -100,10 +133,8 @@ def main(): input_height = metadata.height input_width = metadata.width - fraction_of_total_frames_to_sample = [0.005, 0.01, 0.05, 0.1] - fraction_of_input_dimensions = [0.5, 0.25, 0.125] - for num_fraction in fraction_of_total_frames_to_sample: + for num_fraction in args.total_frame_fractions: num_frames_to_sample = math.ceil(metadata.num_frames * num_fraction) print( f"Sampling {num_fraction * 100}%, {num_frames_to_sample}, of {metadata.num_frames} frames" @@ -112,51 +143,49 @@ def main(): i * duration / num_frames_to_sample for i in range(num_frames_to_sample) ] - for dims_fraction in 
fraction_of_input_dimensions: + for dims_fraction in args.input_dimension_fractions: dims = (int(input_height * dims_fraction), int(input_width * dims_fraction)) times = bench( - torchvision_resize, path, uniform_timestamps, dims, num_exp=args.num_exp + torchvision_resize, + path, + uniform_timestamps, + dims, + args.num_threads, + num_exp=args.num_exp, ) report_stats(times, prefix=f"torchvision_resize({dims})") times = bench( - decoder_native_resize, + decoder_resize, path, uniform_timestamps, dims, + args.num_threads, num_exp=args.num_exp, ) - report_stats(times, prefix=f"decoder_native_resize({dims})") - print() + report_stats(times, prefix=f"decoder_resize({dims})") - center_x = (input_height - dims[0]) // 2 - center_y = (input_width - dims[1]) // 2 times = bench( torchvision_crop, path, uniform_timestamps, dims, - center_x, - center_y, + args.num_threads, num_exp=args.num_exp, ) - report_stats( - times, prefix=f"torchvision_crop({dims}, {center_x}, {center_y})" - ) + report_stats(times, prefix=f"torchvision_crop({dims})") times = bench( - decoder_native_crop, + decoder_crop, path, uniform_timestamps, dims, - center_x, - center_y, + args.num_threads, num_exp=args.num_exp, ) - report_stats( - times, prefix=f"decoder_native_crop({dims}, {center_x}, {center_y})" - ) + report_stats(times, prefix=f"decoder_crop({dims})") + print() diff --git a/benchmarks/decoders/gpu_benchmark.py b/benchmarks/decoders/gpu_benchmark.py index 4300643dd..638737e88 100644 --- a/benchmarks/decoders/gpu_benchmark.py +++ b/benchmarks/decoders/gpu_benchmark.py @@ -7,8 +7,9 @@ import torch.utils.benchmark as benchmark -import torchcodec import torchvision.transforms.v2.functional as F +from torchcodec._core import get_next_frame +from torchcodec._core.ops import _add_video_stream, create_from_file RESIZED_WIDTH = 256 RESIZED_HEIGHT = 256 @@ -25,7 +26,7 @@ def decode_full_video(video_path, decode_device_string, resize_device_string): # We use the core API instead of SimpleVideoDecoder because 
the core API # allows us to natively resize as part of the decode step. print(f"{decode_device_string=} {resize_device_string=}") - decoder = torchcodec._core.create_from_file(video_path) + decoder = create_from_file(video_path) num_threads = None if "cuda" in decode_device_string: num_threads = 1 @@ -34,7 +35,7 @@ def decode_full_video(video_path, decode_device_string, resize_device_string): if "native" in resize_device_string: resize_spec = f"resize, {RESIZED_HEIGHT}, {RESIZED_WIDTH}" - torchcodec._core._add_video_stream( + _add_video_stream( decoder, stream_index=-1, device=decode_device_string, @@ -46,7 +47,7 @@ def decode_full_video(video_path, decode_device_string, resize_device_string): frame_count = 0 while True: try: - frame, *_ = torchcodec._core.get_next_frame(decoder) + frame, *_ = get_next_frame(decoder) if resize_device_string != "none" and "native" not in resize_device_string: frame = transfer_and_resize_frame(frame, resize_device_string) diff --git a/benchmarks/decoders/memprofile_decoders.py b/benchmarks/decoders/memprofile_decoders.py index 16bc42dc6..a78eb1263 100644 --- a/benchmarks/decoders/memprofile_decoders.py +++ b/benchmarks/decoders/memprofile_decoders.py @@ -9,7 +9,8 @@ import torch from memory_profiler import profile -from torchcodec._core import add_video_stream, create_from_file, get_next_frame +from torchcodec._core import get_next_frame +from torchcodec._core.ops import add_video_stream, create_from_file torch._dynamo.config.cache_size_limit = 100 torch._dynamo.config.capture_dynamic_output_shape_ops = True diff --git a/benchmarks/encoders/benchmark_encoders.py b/benchmarks/encoders/benchmark_encoders.py new file mode 100644 index 000000000..f59501f80 --- /dev/null +++ b/benchmarks/encoders/benchmark_encoders.py @@ -0,0 +1,227 @@ +#!/usr/bin/env python3 +import shutil +import subprocess +import tempfile +from argparse import ArgumentParser +from pathlib import Path +from time import perf_counter_ns + +import pynvml +import torch 
+from torchcodec.decoders import VideoDecoder +from torchcodec.encoders import VideoEncoder + +pynvml.nvmlInit() +handle = pynvml.nvmlDeviceGetHandleByIndex(0) + +FRAME_RATE = 30 +DEFAULT_VIDEO_PATH = "test/resources/nasa_13013.mp4" +# Alternatively, run this command to generate a longer test video: +# ffmpeg -f lavfi -i testsrc2=duration=600:size=1280x720:rate=30 -c:v libx264 -pix_fmt yuv420p test/resources/testsrc2_10min.mp4 + + +def bench(f, average_over=50, warmup=2, gpu_monitoring=False, **f_kwargs): + for _ in range(warmup): + f(**f_kwargs) + + times = [] + utilizations = [] + memory_usage = [] + + for _ in range(average_over): + start = perf_counter_ns() + f(**f_kwargs) + end = perf_counter_ns() + times.append(end - start) + + if gpu_monitoring: + util = pynvml.nvmlDeviceGetEncoderUtilization(handle)[0] + mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle) + mem_used = mem_info.used / (1_000_000) # Convert bytes to MB + utilizations.append(util) + memory_usage.append(mem_used) + + times_tensor = torch.tensor(times).float() + return times_tensor, { + "utilization": torch.tensor(utilizations).float() if gpu_monitoring else None, + "memory_used": torch.tensor(memory_usage).float() if gpu_monitoring else None, + } + + +def report_stats(times, num_frames, nvenc_metrics=None, prefix="", unit="ms"): + fps = num_frames * 1e9 / times.median() + + mul = { + "ns": 1, + "µs": 1e-3, + "ms": 1e-6, + "s": 1e-9, + }[unit] + unit_times = times * mul + med = unit_times.median().item() + max = unit_times.max().item() + print(f"\n{prefix} {med = :.2f} {unit}, {max = :.2f} {unit}, fps = {fps:.1f}") + + if nvenc_metrics is not None: + mem_used_max = nvenc_metrics["memory_used"].max().item() + mem_used_median = nvenc_metrics["memory_used"].median().item() + util_max = nvenc_metrics["utilization"].max().item() + + print( + f"GPU memory used: med = {mem_used_median:.1f} MB, max = {mem_used_max:.1f} MB" + ) + print( + f"NVENC utilization: med = 
{nvenc_metrics['utilization'].median():.1f}%, max = {util_max:.1f}%" + ) + + +def encode_torchcodec(frames, output_path, device="cpu"): + encoder = VideoEncoder(frames=frames, frame_rate=FRAME_RATE) + if device == "cuda": + encoder.to_file(dest=output_path, codec="h264_nvenc", extra_options={"qp": 0}) + else: + encoder.to_file(dest=output_path, codec="libx264", crf=0) + + +def write_raw_frames(frames, raw_path): + # Convert NCHW to NHWC for raw video format + raw_frames = frames.permute(0, 2, 3, 1) + with open(raw_path, "wb") as f: + f.write(raw_frames.cpu().numpy().tobytes()) + + +def encode_ffmpeg_cli( + frames, raw_path, output_path, device="cpu", skip_write_frames=False +): + # Write frames during benchmarking function by default unless skip_write_frames flag used + if not skip_write_frames: + write_raw_frames(frames, raw_path) + + ffmpeg_cmd = [ + "ffmpeg", + "-y", + "-f", + "rawvideo", + "-pix_fmt", + "rgb24", + "-s", + f"{frames.shape[3]}x{frames.shape[2]}", + "-r", + str(FRAME_RATE), + "-i", + raw_path, + "-c:v", + "h264_nvenc" if device == "cuda" else "libx264", + "-pix_fmt", + "yuv420p", + ] + ffmpeg_cmd.extend(["-qp", "0"] if device == "cuda" else ["-crf", "0"]) + ffmpeg_cmd.extend([str(output_path)]) + subprocess.run(ffmpeg_cmd, check=True, capture_output=True) + + +def main(): + parser = ArgumentParser() + parser.add_argument( + "--path", type=str, help="Path to input video file", default=DEFAULT_VIDEO_PATH + ) + parser.add_argument( + "--average-over", + type=int, + default=30, + help="Number of runs to average over", + ) + parser.add_argument( + "--max-frames", + type=int, + default=None, + help="Maximum number of frames to decode for benchmarking. 
By default, all frames will be decoded.", + ) + parser.add_argument( + "--skip-write-frames", + action="store_true", + help="Do not write raw frames in FFmpeg CLI benchmarks", + ) + args = parser.parse_args() + decoder = VideoDecoder(str(args.path)) + frames = decoder.get_frames_in_range(start=0, stop=args.max_frames).data + + cuda_available = torch.cuda.is_available() + if not cuda_available: + print("CUDA not available. GPU benchmarks will be skipped.") + + print( + f"Benchmarking {len(frames)} frames from {Path(args.path).name} over {args.average_over} runs:" + ) + gpu_frames = frames.cuda() if cuda_available else None + print( + f"Decoded {frames.shape[0]} frames of size {frames.shape[2]}x{frames.shape[3]}" + ) + + temp_dir = Path(tempfile.mkdtemp()) + raw_frames_path = temp_dir / "input_frames.raw" + + # If skip_write_frames is True, we will not benchmark the time it takes to write the frames. + # Here, we still write the frames for FFmpeg to use! + if args.skip_write_frames: + write_raw_frames(frames, str(raw_frames_path)) + + if cuda_available: + # Benchmark torchcodec on GPU + gpu_output = temp_dir / "torchcodec_gpu.mp4" + times, nvenc_metrics = bench( + encode_torchcodec, + frames=gpu_frames, + output_path=str(gpu_output), + device="cuda", + gpu_monitoring=True, + average_over=args.average_over, + ) + report_stats( + times, frames.shape[0], nvenc_metrics, prefix="VideoEncoder on GPU" + ) + # Benchmark FFmpeg CLI on GPU + ffmpeg_gpu_output = temp_dir / "ffmpeg_gpu.mp4" + times, nvenc_metrics = bench( + encode_ffmpeg_cli, + frames=gpu_frames, + raw_path=str(raw_frames_path), + output_path=str(ffmpeg_gpu_output), + device="cuda", + gpu_monitoring=True, + skip_write_frames=args.skip_write_frames, + average_over=args.average_over, + ) + prefix = "FFmpeg CLI on GPU " + report_stats(times, frames.shape[0], nvenc_metrics, prefix=prefix) + + # Benchmark torchcodec on CPU + cpu_output = temp_dir / "torchcodec_cpu.mp4" + times, _nvenc_metrics = bench( + 
encode_torchcodec, + frames=frames, + output_path=str(cpu_output), + device="cpu", + average_over=args.average_over, + ) + report_stats(times, frames.shape[0], prefix="VideoEncoder on CPU") + + # Benchmark FFmpeg CLI on CPU + ffmpeg_cpu_output = temp_dir / "ffmpeg_cpu.mp4" + times, _nvenc_metrics = bench( + encode_ffmpeg_cli, + frames=frames, + raw_path=str(raw_frames_path), + output_path=str(ffmpeg_cpu_output), + device="cpu", + skip_write_frames=args.skip_write_frames, + average_over=args.average_over, + ) + prefix = "FFmpeg CLI on CPU " + report_stats(times, frames.shape[0], prefix=prefix) + + shutil.rmtree(temp_dir, ignore_errors=True) + + +if __name__ == "__main__": + main() diff --git a/docs/requirements.txt b/docs/requirements.txt index ba6848490..5ac0663e1 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,12 +1,15 @@ -sphinx-gallery>0.11 -sphinx==5.0.0 -sphinx_design +sphinx==7.2.6 +-e git+https://github.com/pytorch/pytorch_sphinx_theme.git@pytorch_sphinx_theme2#egg=pytorch_sphinx_theme2 +sphinx-gallery>=0.14.0 +sphinx_design>=0.6.1 sphinx_copybutton sphinx-tabs +sphinx-sitemap>=2.7.1 +sphinxcontrib-mermaid>=1.0.0 +docutils>=0.18.1,<0.21 matplotlib torchvision ipython fsspec aiohttp joblib --e git+https://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme diff --git a/docs/source/_static/css/custom_torchcodec.css b/docs/source/_static/css/custom_torchcodec.css deleted file mode 100644 index 6c702e1f2..000000000 --- a/docs/source/_static/css/custom_torchcodec.css +++ /dev/null @@ -1,192 +0,0 @@ -/** - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. 
- */ - -/* sphinx-design styles for cards/tabs */ - - -:root { - --sd-color-info: #ee4c2c; - --sd-color-primary: #6c6c6d; - --sd-color-primary-highlight: #f3f4f7; - --sd-color-card-border-hover: #ee4c2c; - --sd-color-card-border: #f3f4f7; - --sd-color-card-background: #fff; - --sd-color-card-text: inherit; - --sd-color-card-header: transparent; - --sd-color-card-footer: transparent; - --sd-color-tabs-label-active: #ee4c2c; - --sd-color-tabs-label-hover: #ee4c2c; - --sd-color-tabs-label-inactive: #6c6c6d; - --sd-color-tabs-underline-active: #ee4c2c; - --sd-color-tabs-underline-hover: #fabdbd; - --sd-color-tabs-underline-inactive: transparent; - --sd-color-tabs-overline: rgb(222, 222, 222); - --sd-color-tabs-underline: rgb(222, 222, 222); -} - -.sd-text-info { - color: #ee4c2c; -} - -.sd-card-img-top { - background: #ee4c2c; - height: 5px !important; -} - -.sd-card { - position: relative; - background-color: #fff; - opacity: 1.0; - border-radius: 0px; - width: 30%; - border: none; - padding-bottom: 0px; -} - - -.sd-card-img:hover { - opacity: 1.0; - background-color: #f3f4f7; -} - - -.sd-card:after { - display: block; - opacity: 1; - content: ''; - border-bottom: solid 1px #ee4c2c; - background-color: #fff; - transform: scaleX(0); - transition: transform .250s ease-in-out; - transform-origin: 0% 50%; -} - -.sd-card:hover { - background-color: #fff; - opacity: 1; - border-top: 1px solid #f3f4f7; - border-left: 1px solid #f3f4f7; - border-right: 1px solid #f3f4f7; -} - -.sd-card:hover:after { - transform: scaleX(1); -} - -.card-prerequisites:hover { - transition: none; - border: none; -} - -.card-prerequisites:hover:after { - transition: none; - transform: none; -} - -.card-prerequisites:after { - display: block; - content: ''; - border-bottom: none; - background-color: #fff; - transform: none; - transition: none; - transform-origin: none; -} - - -details.sd-dropdown { - font-weight: 300; - width: auto; -} - -details.sd-dropdown:after { - border: none; - transition: 
none; -} - -details.sd-dropdown:hover { - border: none; - transition: none; -} - -details.sd-dropdown .sd-summary-content { - font-weight: 300; -} - -details.sd-dropdown .highlight .n { - font-weight: normal; -} - -.et-page-column1 { - float: left; - width: 70%; - font-size: 1rem; -} - -.et-page-column2 { - float: right; - padding-top: 40px; - padding-left: 60px; - padding-right: 60px; - padding-bottom: 60px; - width: 30%; -} - -.et-page-column-row:after { - content: ""; - display: table; - clear: both; -} - -/* For screens smaller than 768px (typical mobile devices) */ -@media screen and (max-width: 768px) { - .et-page-column1, .et-page-column2 { - float: none; /* Remove floats */ - width: 100%; /* Full width for both columns */ - padding: 0; - font-size: 1rem; - } - - .et-page-column2 img { - display: none; - } - .et-page-column-row:after { - content: ""; - display: table; - clear: both; - } -} - -article.pytorch-article .class .method dt { - border-top: none; -} - -article.pytorch-article .class .simple dt { - border-top: none; -} - -article.pytorch-article .function dt.sig { - border-top: none; -} - -/* Fix for Sphinx gallery thumbnails. 
-See https://github.com/sphinx-gallery/sphinx-gallery/issues/990 -*/ -article.pytorch-article .sphx-glr-thumbnails .sphx-glr-thumbcontainer { - width: unset; - margin-right: 0; - margin-left: 0; -} -article.pytorch-article div.section div.wy-table-responsive tbody td { - width: 50%; -} - -article.pytorch-article section#glossary dl.simple.glossary dt { - font-weight: bold; - font-size: x-large; -} diff --git a/docs/source/_static/thumbnails/grumps_6.jpg b/docs/source/_static/thumbnails/grumps_6.jpg new file mode 100644 index 000000000..081764555 Binary files /dev/null and b/docs/source/_static/thumbnails/grumps_6.jpg differ diff --git a/docs/source/_static/thumbnails/grumps_audio.jpg b/docs/source/_static/thumbnails/grumps_audio.jpg new file mode 100644 index 000000000..44fffe445 Binary files /dev/null and b/docs/source/_static/thumbnails/grumps_audio.jpg differ diff --git a/docs/source/_static/thumbnails/grumps_audio2.jpg b/docs/source/_static/thumbnails/grumps_audio2.jpg new file mode 100644 index 000000000..ff2a3c47a Binary files /dev/null and b/docs/source/_static/thumbnails/grumps_audio2.jpg differ diff --git a/docs/source/_static/thumbnails/grumps_brrrr.jpg b/docs/source/_static/thumbnails/grumps_brrrr.jpg new file mode 100644 index 000000000..fa988b07e Binary files /dev/null and b/docs/source/_static/thumbnails/grumps_brrrr.jpg differ diff --git a/docs/source/_static/thumbnails/grumps_frame_mappings.jpg b/docs/source/_static/thumbnails/grumps_frame_mappings.jpg new file mode 100644 index 000000000..465174aaa Binary files /dev/null and b/docs/source/_static/thumbnails/grumps_frame_mappings.jpg differ diff --git a/docs/source/_static/thumbnails/grumps_parallel.jpg b/docs/source/_static/thumbnails/grumps_parallel.jpg new file mode 100644 index 000000000..2be5015f4 Binary files /dev/null and b/docs/source/_static/thumbnails/grumps_parallel.jpg differ diff --git a/docs/source/_static/thumbnails/grumps_seek_mode.jpg 
b/docs/source/_static/thumbnails/grumps_seek_mode.jpg new file mode 100644 index 000000000..402eddc67 Binary files /dev/null and b/docs/source/_static/thumbnails/grumps_seek_mode.jpg differ diff --git a/docs/source/_static/thumbnails/not_grumps_encoding_video.jpg b/docs/source/_static/thumbnails/not_grumps_encoding_video.jpg new file mode 100644 index 000000000..2720367e9 Binary files /dev/null and b/docs/source/_static/thumbnails/not_grumps_encoding_video.jpg differ diff --git a/docs/source/_templates/layout.html b/docs/source/_templates/layout.html deleted file mode 100644 index 010a4d6d7..000000000 --- a/docs/source/_templates/layout.html +++ /dev/null @@ -1,21 +0,0 @@ -{% extends "!layout.html" %} - -{% block sidebartitle %} -
- {{ version }} ▼ -
- {% include "searchbox.html" %} -{% endblock %} - - -{% block footer %} - - - - -{% endblock %} diff --git a/docs/source/api_ref.rst b/docs/source/api_ref.rst new file mode 100644 index 000000000..f4ffe34dc --- /dev/null +++ b/docs/source/api_ref.rst @@ -0,0 +1,11 @@ +API Reference +============= + +.. toctree:: + :maxdepth: 1 + + api_ref_torchcodec + api_ref_decoders + api_ref_encoders + api_ref_samplers + api_ref_transforms diff --git a/docs/source/api_ref_decoders.rst b/docs/source/api_ref_decoders.rst index 1417d7aea..40ae75101 100644 --- a/docs/source/api_ref_decoders.rst +++ b/docs/source/api_ref_decoders.rst @@ -19,17 +19,30 @@ For an audio decoder tutorial, see: :ref:`sphx_glr_generated_examples_decoding_a VideoDecoder AudioDecoder +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: dataclass.rst + + VideoStreamMetadata + AudioStreamMetadata + + +CUDA decoding utils +------------------- + .. autosummary:: :toctree: generated/ :nosignatures: :template: function.rst set_cuda_backend + set_nvdec_cache_capacity + get_nvdec_cache_capacity .. autosummary:: :toctree: generated/ :nosignatures: :template: dataclass.rst - VideoStreamMetadata - AudioStreamMetadata + CpuFallbackStatus diff --git a/docs/source/api_ref_encoders.rst b/docs/source/api_ref_encoders.rst index 52c7295bc..6c7fc825d 100644 --- a/docs/source/api_ref_encoders.rst +++ b/docs/source/api_ref_encoders.rst @@ -16,3 +16,4 @@ For an audio decoder tutorial, see: :ref:`sphx_glr_generated_examples_encoding_a :template: class.rst AudioEncoder + VideoEncoder diff --git a/docs/source/api_ref_transforms.rst b/docs/source/api_ref_transforms.rst new file mode 100644 index 000000000..18bffabae --- /dev/null +++ b/docs/source/api_ref_transforms.rst @@ -0,0 +1,21 @@ +.. _transforms: + +===================== +torchcodec.transforms +===================== + +.. automodule:: torchcodec.transforms + +.. 
currentmodule:: torchcodec.transforms + +For a tutorial, see: :ref:`sphx_glr_generated_examples_decoding_transforms.py`. + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: dataclass.rst + + DecoderTransform + CenterCrop + RandomCrop + Resize diff --git a/docs/source/conf.py b/docs/source/conf.py index ba5247372..34e188dd9 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -18,21 +18,13 @@ # All configuration values have a default; values that are commented out # serve to show the default. -# If extensions (or modules to document with autodoc) are in another directory, -# add these directories to sys.path here. If the directory is relative to the -# documentation root, use os.path.abspath to make it absolute, like shown here. -# -# import os -# import sys -# sys.path.insert(0, os.path.abspath('.')) - import os import sys -import pytorch_sphinx_theme +import pytorch_sphinx_theme2 import torchcodec -sys.path.append(os.path.abspath(".")) +sys.path.insert(0, os.path.abspath(".")) # -- General configuration ------------------------------------------------ @@ -55,6 +47,9 @@ "sphinx_tabs.tabs", "sphinx_design", "sphinx_copybutton", + "sphinx_sitemap", + "sphinxcontrib.mermaid", + "pytorch_sphinx_theme2", ] @@ -81,12 +76,15 @@ def __call__(self, filename): "approximate_mode.py", "sampling.py", "parallel_decoding.py", + "performance_tips.py", "custom_frame_mappings.py", + "transforms.py", ] else: assert "examples/encoding" in self.src_dir order = [ "audio_encoding.py", + "video_encoding.py", ] try: @@ -133,13 +131,18 @@ def __call__(self, filename): # Add any paths that contain templates here, relative to this directory. -templates_path = ["_templates"] +templates_path = [ + "_templates", + os.path.join(os.path.dirname(pytorch_sphinx_theme2.__file__), "templates"), +] # The suffix(es) of source filenames. 
# You can specify multiple suffix as a list of string: # source_suffix = [".rst"] +version = ".".join(torchcodec.__version__.split(".")[:2]) + html_title = f"TorchCodec {torchcodec.__version__} Documentation" # The master toctree document. @@ -173,26 +176,51 @@ def __call__(self, filename): # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # -html_theme = "pytorch_sphinx_theme" -html_theme_path = [pytorch_sphinx_theme.get_html_theme_path()] +html_theme = "pytorch_sphinx_theme2" +html_theme_path = [pytorch_sphinx_theme2.get_html_theme_path()] # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the # documentation. # html_theme_options = { - "collapse_navigation": False, - "display_version": True, - "logo_only": True, - "pytorch_project": "docs", - "navigation_with_keys": True, + "navigation_with_keys": False, "analytics_id": "GTM-T8XT4PS", + "icon_links": [ + { + "name": "X", + "url": "https://x.com/PyTorch", + "icon": "fa-brands fa-x-twitter", + }, + { + "name": "GitHub", + "url": "https://github.com/meta-pytorch/torchcodec", + "icon": "fa-brands fa-github", + }, + { + "name": "Discourse", + "url": "https://dev-discuss.pytorch.org/", + "icon": "fa-brands fa-discourse", + }, + { + "name": "PyPi", + "url": "https://pypi.org/project/torchcodec/", + "icon": "fa-brands fa-python", + }, + ], + "use_edit_page_button": True, + "navbar_center": "navbar-nav", + "navbar_start": ["navbar-logo", "version-switcher"], + "logo": { + "text": "TorchCodec", + }, + "switcher": { + "json_url": "https://meta-pytorch.org/torchcodec/torchcodec-versions.json", + "version_match": version, + }, + "show_version_warning_banner": True, } -html_logo = "_static/img/pytorch-logo-dark.svg" - -html_css_files = ["css/custom_torchcodec.css"] - # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. 
They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". @@ -209,11 +237,38 @@ def __call__(self, filename): intersphinx_mapping = { "python": ("https://docs.python.org/3/", None), "torch": ("https://pytorch.org/docs/stable/", None), + "torchvision": ("https://docs.pytorch.org/vision/stable/", None), "numpy": ("https://numpy.org/doc/stable/", None), "PIL": ("https://pillow.readthedocs.io/en/stable/", None), "matplotlib": ("https://matplotlib.org/stable/", None), } +# html_context for theme2 +theme_variables = pytorch_sphinx_theme2.get_theme_variables() + +html_context = { + "theme_variables": theme_variables, + "display_github": True, + "github_url": "https://github.com", + "github_user": "meta-pytorch", + "github_repo": "torchcodec", + "feedback_url": "https://github.com/meta-pytorch/torchcodec", + "github_version": "main", + "doc_path": "docs/source", + "library_links": [], + "community_links": theme_variables.get("community_links", []), + "language_bindings_links": html_theme_options.get("language_bindings_links", []), +} + +# sitemap config +html_baseurl = "https://meta-pytorch.org/torchcodec/stable/" +sitemap_locales = [None] +sitemap_excludes = [ + "search.html", + "genindex.html", +] +sitemap_url_scheme = "{link}" + def inject_minigalleries(app, what, name, obj, options, lines): """Inject a minigallery into a docstring. diff --git a/docs/source/index.rst b/docs/source/index.rst index 85f9a067c..0276daa77 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -2,7 +2,7 @@ Welcome to the TorchCodec documentation! ======================================== TorchCodec is a Python library for decoding video and audio data into PyTorch -tensors, on CPU and CUDA GPU. It also supports audio encoding, and video encoding will come soon! +tensors, on CPU and CUDA GPU. It also supports audio and video encoding! It aims to be fast, easy to use, and well integrated into the PyTorch ecosystem. 
If you want to use PyTorch to train ML models on videos and audio, TorchCodec is how you turn these into data. @@ -25,8 +25,7 @@ Installation instructions .. grid-item-card:: :octicon:`file-code;1em` Installation instructions - :img-top: _static/img/card-background.svg - :link: https://github.com/pytorch/torchcodec?tab=readme-ov-file#installing-torchcodec + :link: https://github.com/meta-pytorch/torchcodec?tab=readme-ov-file#installing-torchcodec :link-type: url How to install TorchCodec @@ -38,7 +37,6 @@ Decoding .. grid-item-card:: :octicon:`file-code;1em` Getting Started with TorchCodec - :img-top: _static/img/card-background.svg :link: generated_examples/decoding/basic_example.html :link-type: url @@ -46,7 +44,6 @@ Decoding .. grid-item-card:: :octicon:`file-code;1em` Audio Decoding - :img-top: _static/img/card-background.svg :link: generated_examples/decoding/audio_decoding.html :link-type: url @@ -54,7 +51,6 @@ Decoding .. grid-item-card:: :octicon:`file-code;1em` GPU decoding - :img-top: _static/img/card-background.svg :link: generated_examples/decoding/basic_cuda_example.html :link-type: url @@ -62,7 +58,6 @@ Decoding .. grid-item-card:: :octicon:`file-code;1em` Streaming video - :img-top: _static/img/card-background.svg :link: generated_examples/decoding/file_like.html :link-type: url @@ -70,7 +65,6 @@ Decoding .. grid-item-card:: :octicon:`file-code;1em` Parallel decoding - :img-top: _static/img/card-background.svg :link: generated_examples/decoding/parallel_decoding.html :link-type: url @@ -78,12 +72,25 @@ Decoding .. grid-item-card:: :octicon:`file-code;1em` Clip sampling - :img-top: _static/img/card-background.svg :link: generated_examples/decoding/sampling.html :link-type: url How to sample regular and random clips from a video + .. grid-item-card:: :octicon:`file-code;1em` + Decoder transforms + :link: generated_examples/decoding/transforms.html + :link-type: url + + How to apply transforms while decoding + + .. 
grid-item-card:: :octicon:`file-code;1em` + Performance Tips + :link: generated_examples/decoding/performance_tips.html + :link-type: url + + Tips for optimizing video decoding performance + Encoding ^^^^^^^^ @@ -92,36 +99,38 @@ Encoding .. grid-item-card:: :octicon:`file-code;1em` Audio Encoding - :img-top: _static/img/card-background.svg :link: generated_examples/encoding/audio_encoding.html :link-type: url How encode audio samples + .. grid-item-card:: :octicon:`file-code;1em` + Video Encoding + :link: generated_examples/encoding/video_encoding.html + :link-type: url + + How to encode video frames + .. toctree:: :maxdepth: 1 - :caption: TorchCodec documentation :hidden: - Home - glossary + Installation .. toctree:: :maxdepth: 1 - :caption: Examples and tutorials :hidden: - Installation instructions generated_examples/index +.. toctree:: + :maxdepth: 1 + :hidden: + + api_ref .. toctree:: - :glob: :maxdepth: 1 - :caption: API Reference :hidden: - api_ref_torchcodec - api_ref_decoders - api_ref_encoders - api_ref_samplers + glossary diff --git a/examples/decoding/approximate_mode.py b/examples/decoding/approximate_mode.py index 62abee801..15a19321f 100644 --- a/examples/decoding/approximate_mode.py +++ b/examples/decoding/approximate_mode.py @@ -33,6 +33,7 @@ from time import perf_counter_ns +# sphinx_gallery_thumbnail_path = '_static/thumbnails/grumps_seek_mode.jpg' # Video source: https://www.pexels.com/video/dog-eating-854132/ # License: CC0. Author: Coverr. url = "https://videos.pexels.com/video-files/854132/854132-sd_640_360_25fps.mp4" @@ -66,7 +67,7 @@ # Performance: ``VideoDecoder`` creation # -------------------------------------- # -# In terms of performance, the ``seek_mode`` parameter ultimately affects the +# In terms of performance, the ``seek_mode`` parameter mainly affects the # **creation** of a :class:`~torchcodec.decoders.VideoDecoder` object. The # longer the video, the higher the performance gain. 
@@ -104,7 +105,7 @@ def bench(f, average_over=50, warmup=2, **f_kwargs): # --------------------------------------------- # # Strictly speaking the ``seek_mode`` parameter only affects the performance of -# the :class:`~torchcodec.decoders.VideoDecoder` creation. It does not have a +# the :class:`~torchcodec.decoders.VideoDecoder` creation. It usually does not have a # direct effect on the performance of frame decoding or sampling. **However**, # because frame decoding and sampling patterns typically involve the creation of # the :class:`~torchcodec.decoders.VideoDecoder` (one per video), ``seek_mode`` @@ -168,8 +169,10 @@ def sample_clips(seek_mode): # duration), and also builds an internal index of frames and key-frames. This # internal index is potentially more accurate than the one in the file's # headers, which leads to more accurate seeking behavior. -# Without the scan, TorchCodec relies only on the metadata contained in the -# file, which may not always be as accurate. +# Without the scan (in approximate mode), TorchCodec relies only on the metadata +# contained in the file, which may not always be as accurate. In some rare +# cases, relying on this less accurate data may also lead to slower frame +# decoding, because it can involve unnecessary seeks. # # Which mode should I use? # ------------------------ @@ -177,11 +180,10 @@ # The general rule of thumb is as follows: # # - If you really care about exactness of frame seeking, use "exact". -# - If you can sacrifice exactness of seeking for speed, which is usually the -# case when doing clip sampling, use "approximate". -# - If your videos don't have variable framerate and their metadata is correct, -# then "approximate" mode is a net win: it will be just as accurate as the -# "exact" mode while still being significantly faster. +# - If your videos are short (less than a few minutes) then "exact" will usually +# be preferable, as the scan's fixed cost will be negligible. 
+# - For long videos, if you can sacrifice exactness of seeking for speed, which +# is usually the case when doing clip sampling, consider using "approximate". # %% shutil.rmtree(temp_dir) diff --git a/examples/decoding/audio_decoding.py b/examples/decoding/audio_decoding.py index 3d41e350d..95ac36082 100644 --- a/examples/decoding/audio_decoding.py +++ b/examples/decoding/audio_decoding.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. + """ ======================================== Decoding audio streams with AudioDecoder @@ -25,6 +26,7 @@ def play_audio(samples): return Audio(samples.data, rate=samples.sample_rate) +# sphinx_gallery_thumbnail_path = '_static/thumbnails/grumps_audio.jpg' # Audio source is CC0: https://opengameart.org/content/town-theme-rpg # Attribution: cynicmusic.com pixelsphere.org url = "https://opengameart.org/sites/default/files/TownTheme.mp3" diff --git a/examples/decoding/basic_cuda_example.py b/examples/decoding/basic_cuda_example.py index 8f82940c0..45f9ea1a9 100644 --- a/examples/decoding/basic_cuda_example.py +++ b/examples/decoding/basic_cuda_example.py @@ -18,32 +18,10 @@ running the transform steps. Encoded packets are often much smaller than decoded frames so CUDA decoding also uses less PCI-e bandwidth. -When to and when not to use CUDA Decoding ------------------------------------------ - -CUDA Decoding can offer speed-up over CPU Decoding in a few scenarios: - -#. You are decoding a large resolution video -#. You are decoding a large batch of videos that's saturating the CPU -#. You want to do whole-image transforms like scaling or convolutions on the decoded tensors - after decoding -#. Your CPU is saturated and you want to free it up for other work - - -Here are situations where CUDA Decoding may not make sense: - -#. You want bit-exact results compared to CPU Decoding -#. 
You have small resolution videos and the PCI-e transfer latency is large -#. Your GPU is already busy and CPU is not - -It's best to experiment with CUDA Decoding to see if it improves your use-case. With -TorchCodec you can simply pass in a device parameter to the -:class:`~torchcodec.decoders.VideoDecoder` class to use CUDA Decoding. - Installing TorchCodec with CUDA Enabled --------------------------------------- -Refer to the installation guide in the `README `_. +Refer to the installation guide in the `README `_. """ @@ -113,6 +91,25 @@ print(frame.data.device) +# %% +# Checking for CPU Fallback +# ------------------------------------- +# +# In some cases, CUDA decoding may fall back to CPU decoding. This can happen +# when the video codec or format is not supported by the NVDEC hardware decoder, or when NVCUVID wasn't found. +# TorchCodec provides the :class:`~torchcodec.decoders.CpuFallbackStatus` class +# to help you detect when this fallback occurs. +# +# You can access the fallback status via the +# :attr:`~torchcodec.decoders.VideoDecoder.cpu_fallback` attribute: + +with set_cuda_backend("beta"): + decoder = VideoDecoder(video_file, device="cuda") + +# Check and print the CPU fallback status +print(decoder.cpu_fallback) + + # %% # Visualizing Frames # ------------------------------------- diff --git a/examples/decoding/basic_example.py b/examples/decoding/basic_example.py index 8440b6814..86fa8e6e4 100644 --- a/examples/decoding/basic_example.py +++ b/examples/decoding/basic_example.py @@ -18,7 +18,6 @@ # plotting utility. You can ignore that part and jump right below to # :ref:`creating_decoder`. 
-from typing import Optional import torch import requests @@ -33,7 +32,7 @@ raw_video_bytes = response.content -def plot(frames: torch.Tensor, title : Optional[str] = None): +def plot(frames: torch.Tensor, title: str | None = None): try: from torchvision.utils import make_grid from torchvision.transforms.v2.functional import to_pil_image diff --git a/examples/decoding/custom_frame_mappings.py b/examples/decoding/custom_frame_mappings.py index a62bc9eb0..1094201fc 100644 --- a/examples/decoding/custom_frame_mappings.py +++ b/examples/decoding/custom_frame_mappings.py @@ -32,6 +32,7 @@ import subprocess import requests +# sphinx_gallery_thumbnail_path = '_static/thumbnails/grumps_frame_mappings.jpg' # Video source: https://www.pexels.com/video/dog-eating-854132/ # License: CC0. Author: Coverr. url = "https://videos.pexels.com/video-files/854132/854132-sd_640_360_25fps.mp4" @@ -82,7 +83,15 @@ # Lets define a simple function to run ffprobe on a video's first stream index, then writes the results in output_json_path. def generate_frame_mappings(video_path, output_json_path, stream_index): - ffprobe_cmd = ["ffprobe", "-i", f"{video_path}", "-select_streams", f"{stream_index}", "-show_frames", "-show_entries", "frame=pts,duration,key_frame", "-of", "json"] + ffprobe_cmd = [ + "ffprobe", + "-i", f"{video_path}", + "-select_streams", f"{stream_index}", + "-show_frames", + "-show_entries", + "frame=pts,duration,key_frame", + "-of", "json", + ] print(f"Running ffprobe:\n{' '.join(ffprobe_cmd)}\n") ffprobe_result = subprocess.run(ffprobe_cmd, check=True, capture_output=True, text=True) with open(output_json_path, "w") as f: @@ -157,7 +166,7 @@ def bench(f, file_like=False, average_over=50, warmup=2, **f_kwargs): # so the performance benefits are realized. 
-def decode_frames(video_path, seek_mode = "exact", custom_frame_mappings = None): +def decode_frames(video_path, seek_mode="exact", custom_frame_mappings=None): decoder = VideoDecoder( source=video_path, seek_mode=seek_mode, diff --git a/examples/decoding/file_like.py b/examples/decoding/file_like.py index 7f302d3c5..238d51094 100644 --- a/examples/decoding/file_like.py +++ b/examples/decoding/file_like.py @@ -28,6 +28,7 @@ class to decode it. But all of the lessons here also apply to audio files and th from time import perf_counter_ns +# sphinx_gallery_thumbnail_path = '_static/thumbnails/grumps_6.jpg' def get_url_content(url): response = requests.get(url, headers={"User-Agent": ""}) if response.status_code != 200: diff --git a/examples/decoding/parallel_decoding.py b/examples/decoding/parallel_decoding.py index b5699a895..e8ad5e0b5 100644 --- a/examples/decoding/parallel_decoding.py +++ b/examples/decoding/parallel_decoding.py @@ -31,7 +31,6 @@ # require efficient processing. You can ignore that part and jump right below to # :ref:`start_parallel_decoding`. -from typing import List import torch import requests import tempfile @@ -44,6 +43,7 @@ from torchcodec.decoders import VideoDecoder +# sphinx_gallery_thumbnail_path = '_static/thumbnails/grumps_parallel.jpg' def bench(f, *args, num_exp=3, warmup=1, **kwargs): """Benchmark a function by running it multiple times and measuring execution time.""" for _ in range(warmup): @@ -74,7 +74,7 @@ def report_stats(times, unit="s"): return med -def split_indices(indices: List[int], num_chunks: int) -> List[List[int]]: +def split_indices(indices: list[int], num_chunks: int) -> list[list[int]]: """Split a list of indices into approximately equal chunks.""" chunk_size = len(indices) // num_chunks chunks = [] @@ -155,7 +155,8 @@ def generate_long_video(temp_dir: str): # Let's start with a sequential approach as our baseline. This processes # frames one by one without any parallelization. 
-def decode_sequentially(indices: List[int], video_path=long_video_path): + +def decode_sequentially(indices: list[int], video_path=long_video_path): """Decode frames sequentially using a single decoder instance.""" decoder = VideoDecoder(video_path, seek_mode="approximate") return decoder.get_frames_at(indices) @@ -173,8 +174,9 @@ def decode_sequentially(indices: List[int], video_path=long_video_path): # via the ``num_ffmpeg_threads`` parameter. This approach uses multiple # threads within FFmpeg itself to accelerate decoding operations. + def decode_with_ffmpeg_parallelism( - indices: List[int], + indices: list[int], num_threads: int, video_path=long_video_path ): @@ -197,10 +199,11 @@ def decode_with_ffmpeg_parallelism( # # Process-based parallelism distributes work across multiple Python processes. + def decode_with_multiprocessing( - indices: List[int], + indices: list[int], num_processes: int, - video_path=long_video_path + video_path=long_video_path, ): """Decode frames using multiple processes with joblib.""" chunks = split_indices(indices, num_chunks=num_processes) @@ -226,8 +229,9 @@ def decode_with_multiprocessing( # Thread-based parallelism uses multiple threads within a single process. # TorchCodec releases the GIL, so this can be very effective. + def decode_with_multithreading( - indices: List[int], + indices: list[int], num_threads: int, video_path=long_video_path ): diff --git a/examples/decoding/performance_tips.py b/examples/decoding/performance_tips.py new file mode 100644 index 000000000..132d7f96f --- /dev/null +++ b/examples/decoding/performance_tips.py @@ -0,0 +1,254 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +.. 
meta:: + :description: Learn how to optimize TorchCodec video decoding performance with batch APIs, approximate seeking, multi-threading, and CUDA acceleration. + +============================================== +TorchCodec Performance Tips and Best Practices +============================================== + +This tutorial consolidates performance optimization techniques for video +decoding with TorchCodec. Learn when and how to apply various strategies +to increase performance. +""" + +# %% +# Overview +# -------- +# +# When decoding videos with TorchCodec, several techniques can significantly +# improve performance depending on your use case. This guide covers: +# +# 1. **Batch APIs** - Decode multiple frames at once +# 2. **Approximate Mode & Keyframe Mappings** - Trade accuracy for speed +# 3. **Multi-threading** - Parallelize decoding across videos or chunks +# 4. **CUDA Acceleration** - Use GPU decoding for supported formats +# 5. **Decoder Native Transforms** - Apply transforms during decoding for memory efficiency +# +# We'll explore each technique and when to use it. + +# %% +# 1. Use Batch APIs When Possible +# -------------------------------- +# +# If you need to decode multiple frames at once, the batch methods are faster than calling single-frame decoding methods multiple times. +# For example, :meth:`~torchcodec.decoders.VideoDecoder.get_frames_at` is faster than calling :meth:`~torchcodec.decoders.VideoDecoder.get_frame_at` multiple times. +# TorchCodec's batch APIs reduce overhead and can leverage internal optimizations. 
+# +# **Key Methods:** +# +# For index-based frame retrieval: +# +# - :meth:`~torchcodec.decoders.VideoDecoder.get_frames_at` for specific indices +# - :meth:`~torchcodec.decoders.VideoDecoder.get_frames_in_range` for ranges +# +# For timestamp-based frame retrieval: +# +# - :meth:`~torchcodec.decoders.VideoDecoder.get_frames_played_at` for timestamps +# - :meth:`~torchcodec.decoders.VideoDecoder.get_frames_played_in_range` for time ranges +# +# **When to use:** +# +# - Decoding multiple frames + +# %% +# .. note:: +# +# For complete examples with runnable code demonstrating batch decoding, +# iteration, and frame retrieval, see :ref:`sphx_glr_generated_examples_decoding_basic_example.py` + +# %% +# 2. Approximate Mode & Keyframe Mappings +# ---------------------------------------- +# +# By default, TorchCodec uses ``seek_mode="exact"``, which performs a :term:`scan` when +# you create the decoder to build an accurate internal index of frames. This +# ensures frame-accurate seeking but takes longer for decoder initialization, +# especially on long videos. + +# %% +# **Approximate Mode** +# ~~~~~~~~~~~~~~~~~~~~ +# +# Setting ``seek_mode="approximate"`` skips the initial :term:`scan` and relies on the +# video file's metadata headers. This dramatically speeds up +# :class:`~torchcodec.decoders.VideoDecoder` creation, particularly for long +# videos, but may result in slightly less accurate seeking in some cases. +# +# +# **Which mode should you use:** +# +# - If you care about exactness of frame seeking, use “exact”. +# - If the video is long and you're only decoding a small number of frames, approximate mode should be faster. + +# %% +# **Custom Frame Mappings** +# ~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# For advanced use cases, you can pre-compute a custom mapping between desired +# frame indices and actual keyframe locations. 
This allows you to speed up :class:`~torchcodec.decoders.VideoDecoder` +# instantiation while maintaining the frame seeking accuracy of ``seek_mode="exact"`` +# +# **When to use:** +# +# - Frame accuracy is critical, so you cannot use approximate mode +# - You can preprocess videos once and then decode them many times +# +# **Performance impact:** speeds up decoder instantiation, similarly to ``seek_mode="approximate"``. + +# %% +# .. note:: +# +# For complete benchmarks showing actual speedup numbers, accuracy comparisons, +# and implementation examples, see :ref:`sphx_glr_generated_examples_decoding_approximate_mode.py` +# and :ref:`sphx_glr_generated_examples_decoding_custom_frame_mappings.py` + +# %% +# 3. Multi-threading for Parallel Decoding +# ----------------------------------------- +# +# When decoding multiple videos or decoding a large number of frames from a single video, there are a few parallelization strategies to speed up the decoding process: +# +# - **FFmpeg-based parallelism** - Using FFmpeg's internal threading capabilities for intra-frame parallelism, where parallelization happens within individual frames rather than across frames. For that, use the `num_ffmpeg_threads` parameter of the :class:`~torchcodec.decoders.VideoDecoder` +# - **Multiprocessing** - Distributing work across multiple processes +# - **Multithreading** - Using multiple threads within a single process +# +# You can use both multiprocessing and multithreading to decode multiple videos in parallel, or to decode a single long video in parallel by splitting it into chunks. + +# %% +# .. note:: +# +# For complete examples comparing +# sequential, ffmpeg-based parallelism, multi-process, and multi-threaded approaches, see +# :ref:`sphx_glr_generated_examples_decoding_parallel_decoding.py` + +# %% +# 4. CUDA Acceleration +# -------------------- +# +# TorchCodec supports GPU-accelerated decoding using NVIDIA's hardware decoder +# (NVDEC) on supported hardware. 
This keeps decoded tensors in GPU memory, +# avoiding expensive CPU-GPU transfers for downstream GPU operations. +# +# **Recommended: use the Beta Interface!!** +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# We recommend you use the new "beta" CUDA interface which is significantly faster than the previous one, and supports the same features: +# +# .. code-block:: python +# +# with set_cuda_backend("beta"): +# decoder = VideoDecoder("file.mp4", device="cuda") +# +# **When to use:** +# +# - Decoding large resolution videos +# - Large batch of videos saturating the CPU +# +# **When NOT to use:** +# +# - You need bit-exact results with CPU decoding +# - Small resolution videos and the PCI-e transfer latency is large +# - GPU is already busy and CPU is idle +# +# **Performance impact:** CUDA decoding can significantly outperform CPU decoding, +# especially for high-resolution videos and when decoding a lot of frames. +# Actual speedup varies by hardware, resolution, and codec. + +# %% +# **Checking for CPU Fallback** +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# In some cases, CUDA decoding may silently fall back to CPU decoding when the +# video codec or format is not supported by NVDEC. You can detect this using +# the :attr:`~torchcodec.decoders.VideoDecoder.cpu_fallback` attribute: +# +# .. code-block:: python +# +# with set_cuda_backend("beta"): +# decoder = VideoDecoder("file.mp4", device="cuda") +# +# # Print detailed fallback status +# print(decoder.cpu_fallback) +# +# .. note:: +# +# The timing of when you can detect CPU fallback differs between backends: +# with the **FFmpeg backend**, you can only check fallback status after decoding at +# least one frame, because FFmpeg determines codec support lazily during decoding; +# with the **BETA backend**, you can check fallback status immediately after +# decoder creation, as the backend checks codec support upfront. 
+# +# For installation instructions, detailed examples, and visual comparisons +# between CPU and CUDA decoding, see :ref:`sphx_glr_generated_examples_decoding_basic_cuda_example.py` + +# %% +# 5. Decoder Native Transforms +# ---------------------------- +# +# TorchCodec supports applying transforms like resize and crop *during* the +# decoding process itself, rather than as a separate post-processing step. +# This can lead to significant memory savings, especially when decoding +# high-resolution videos that will be resized to smaller dimensions. +# +# :class:`~torchcodec.decoders.VideoDecoder` accepts both TorchCodec +# :class:`~torchcodec.transforms.DecoderTransform` objects and TorchVision +# :class:`~torchvision.transforms.v2.Transform` objects as transform +# specifications. TorchVision is **not required** to use decoder transforms. +# +# **Example:** +# +# .. code-block:: python +# +# from torchcodec.decoders import VideoDecoder +# from torchcodec.transforms import Resize +# +# decoder = VideoDecoder( +# "file.mp4", +# transforms=[Resize(size=(480, 640))] +# ) +# +# **When to use:** +# +# - If you are applying a transform pipeline that significantly reduces the +# dimensions of your input frames and memory efficiency matters. +# - If you are using multiple FFmpeg threads, decoder transforms may be faster. +# Experiment with your setup to verify. +# + +# %% +# .. note:: +# +# For complete examples with memory benchmarks, transform pipelines, and +# detailed comparisons between decoder transforms and TorchVision transforms, +# see :ref:`sphx_glr_generated_examples_decoding_transforms.py` + +# %% +# Conclusion +# ---------- +# +# TorchCodec offers multiple performance optimization strategies, each suited to +# different scenarios. Use batch APIs for multi-frame decoding, approximate mode +# for faster initialization, parallel processing for high throughput, CUDA +# acceleration to offload the CPU, and decoder native transforms for memory efficiency. 
+# +# The best results often come from combining techniques. Profile your specific +# use case and apply optimizations incrementally, using the benchmarks in the +# linked examples as a guide. +# +# For more information, see: +# +# - :ref:`sphx_glr_generated_examples_decoding_basic_example.py` - Basic decoding examples +# - :ref:`sphx_glr_generated_examples_decoding_approximate_mode.py` - Approximate mode benchmarks +# - :ref:`sphx_glr_generated_examples_decoding_custom_frame_mappings.py` - Custom frame mappings +# - :ref:`sphx_glr_generated_examples_decoding_parallel_decoding.py` - Parallel decoding strategies +# - :ref:`sphx_glr_generated_examples_decoding_basic_cuda_example.py` - CUDA acceleration guide +# - :ref:`sphx_glr_generated_examples_decoding_transforms.py` - Decoder transforms guide +# - :class:`torchcodec.decoders.VideoDecoder` - Full API reference + +# sphinx_gallery_thumbnail_path = '_static/thumbnails/grumps_brrrr.jpg' diff --git a/examples/decoding/sampling.py b/examples/decoding/sampling.py index 2ca3b6e50..19babbace 100644 --- a/examples/decoding/sampling.py +++ b/examples/decoding/sampling.py @@ -37,7 +37,7 @@ raw_video_bytes = response.content -def plot(frames: torch.Tensor, title : Optional[str] = None): +def plot(frames: torch.Tensor, title: str | None = None): try: from torchvision.utils import make_grid from torchvision.transforms.v2.functional import to_pil_image diff --git a/examples/decoding/transforms.py b/examples/decoding/transforms.py new file mode 100644 index 000000000..40eb7e79a --- /dev/null +++ b/examples/decoding/transforms.py @@ -0,0 +1,343 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +.. meta:: + :description: Learn how to apply transforms during video decoding for improved memory efficiency and performance. 
+ +======================================================= +Decoder Transforms: Applying transforms during decoding +======================================================= + +In this example, we will demonstrate how to use the ``transforms`` parameter of +the :class:`~torchcodec.decoders.VideoDecoder` class. This parameter allows us +to specify a list of :class:`torchcodec.transforms.DecoderTransform` or +:class:`torchvision.transforms.v2.Transform` objects. These objects serve as +transform specifications that the :class:`~torchcodec.decoders.VideoDecoder` +will apply during the decoding process. +""" + +# %% +# First, a bit of boilerplate, definitions that we will use later. You can skip +# ahead to our :ref:`example_video` or :ref:`applying_transforms`. + + +import torch +import requests +import tempfile +from pathlib import Path +import shutil +from time import perf_counter_ns + + +def store_video_to(url: str, local_video_path: Path): + response = requests.get(url, headers={"User-Agent": ""}) + if response.status_code != 200: + raise RuntimeError(f"Failed to download video. {response.status_code = }.") + + with open(local_video_path, 'wb') as f: + for chunk in response.iter_content(): + f.write(chunk) + + +def plot(frames: torch.Tensor, title : str | None = None): + try: + from torchvision.utils import make_grid + from torchvision.transforms.v2.functional import to_pil_image + import matplotlib.pyplot as plt + except ImportError: + print("Cannot plot, please run `pip install torchvision matplotlib`") + return + + plt.rcParams["savefig.bbox"] = "tight" + dpi = 300 + fig, ax = plt.subplots(figsize=(800 / dpi, 600 / dpi), dpi=dpi) + ax.imshow(to_pil_image(make_grid(frames))) + ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[]) + if title is not None: + ax.set_title(title, fontsize=6) + plt.tight_layout() + +# %% +# .. _example_video: +# +# Our example video +# ----------------- +# +# We'll download a video from the internet and store it locally. 
We're +# purposefully retrieving a high resolution video to demonstrate using +# transforms to reduce the dimensions. + + +# Video source: https://www.pexels.com/video/an-african-penguin-at-the-beach-9140346/ +# Author: Taryn Elliott. +url = "https://videos.pexels.com/video-files/9140346/9140346-uhd_3840_2160_25fps.mp4" + +temp_dir = tempfile.mkdtemp() +penguin_video_path = Path(temp_dir) / "penguin.mp4" +store_video_to(url, penguin_video_path) + +from torchcodec.decoders import VideoDecoder +print(f"Penguin video metadata: {VideoDecoder(penguin_video_path).metadata}") + +# %% +# As shown above, the video is 37 seconds long and has a height of 2160 pixels +# and a width of 3840 pixels. +# +# .. note:: +# +# The colloquial way to report the dimensions of this video would be as +# 3840x2160; that is, (`width`, `height`). In the PyTorch ecosystem, image +# dimensions are typically expressed as (`height`, `width`). The remainder +# of this tutorial uses the PyTorch convention of (`height`, `width`) to +# specify image dimensions. + +# %% +# .. _applying_transforms: +# +# Applying transforms during pre-processing +# ----------------------------------------- +# +# A pre-processing pipeline for videos during training will typically apply a +# set of transforms for a variety of reasons. Below is a simple example of +# applying TorchVision's :class:`~torchvision.transforms.v2.Resize` transform to a single +# frame **after** the decoder returns it: + +from torchvision.transforms import v2 + +full_decoder = VideoDecoder(penguin_video_path) +frame = full_decoder[5] +resized_after = v2.Resize(size=(480, 640))(frame) + +plot(resized_after, title="Resized to 480x640 after decoding") + +# %% +# In the example above, ``full_decoder`` returns a video frame that has the +# dimensions (2160, 3840) which is then resized down to (480, 640). 
But with the +# ``transforms`` parameter of :class:`~torchcodec.decoders.VideoDecoder` we can +# specify for the resize to happen **during** decoding! + +resize_decoder = VideoDecoder( + penguin_video_path, + transforms=[v2.Resize(size=(480, 640))] +) +resized_during = resize_decoder[5] + +plot(resized_during, title="Resized to 480x640 during decoding") + +# %% +# Importantly, the two frames are not identical, even though we can see they +# *look* very similar: + +abs_diff = (resized_after.float() - resized_during.float()).abs() +(abs_diff == 0).all() + +# %% +# But they're close enough that models won't be able to tell a difference: +assert (abs_diff <= 1).float().mean() >= 0.998 + + +# %% +# TorchCodec's relationship to TorchVision transforms +# ----------------------------------------------------- +# Notably, in our examples we are passing in TorchVision +# :class:`~torchvision.transforms.v2.Transform` objects as our transforms. +# However, :class:`~torchcodec.decoders.VideoDecoder` accepts TorchVision +# transforms as a matter of convenience. TorchVision is **not required** to use +# decoder transforms. +# +# Every TorchVision transform that :class:`~torchcodec.decoders.VideoDecoder` accepts +# has a complementary transform defined in :mod:`torchcodec.transforms`. We +# would have gotten the same results if we had passed in the +# :class:`torchcodec.transforms.Resize` object that is a part of TorchCodec. +# :class:`~torchcodec.decoders.VideoDecoder` accepts both objects as a matter of +# convenience and to clarify the relationship between the transforms that TorchCodec +# applies and the transforms that TorchVision offers. +# +# While :class:`~torchcodec.decoders.VideoDecoder` accepts TorchVision transforms as +# *specifications*, it is not actually using the TorchVision implementation of these +# transforms. Instead, it is mapping them to equivalent +# `FFmpeg filters `_. 
That is,
+# :class:`torchvision.transforms.v2.Resize` and :class:`torchcodec.transforms.Resize` are mapped to
+# `scale `_; and
+# :class:`torchvision.transforms.v2.CenterCrop` and :class:`torchcodec.transforms.CenterCrop` are mapped to
+# `crop `_.
+#
+# The relationships we ensure between TorchCodec :class:`~torchcodec.transforms.DecoderTransform` objects
+# and TorchVision :class:`~torchvision.transforms.v2.Transform` objects are:
+#
+# 1. The names are the same.
+# 2. Default behaviors are the same.
+# 3. The parameters for the :class:`~torchcodec.transforms.DecoderTransform`
+#    object are a subset of the TorchVision :class:`~torchvision.transforms.v2.Transform`
+#    object.
+# 4. Parameters with the same name control the same behavior and accept a
+#    subset of the same types.
+# 5. The difference between the frames returned by a decoder transform and
+#    the complementary TorchVision transform is such that a model should
+#    not be able to tell the difference.
+#
+# .. note::
+#
+#    Applying the exact same transforms during training and inference is
+#    important for model performance. For example, if you use decoder
+#    transforms to resize frames during training, you should also use decoder
+#    transforms to resize frames during inference. We provide the similarity
+#    guarantees to mitigate the harm when the two techniques are
+#    *unintentionally* mixed. That is, if you use decoder transforms to resize
+#    frames during training, but use TorchVision's
+#    :class:`~torchvision.transforms.v2.Resize` during inference, our guarantees
+#    mitigate the harm to model performance. But we **recommend against** this kind of
+#    mixing.
+#
+# It is appropriate and expected to use some decoder transforms and some TorchVision
+# transforms, as long as the exact same pre-processing operations are performed during
+# training and inference.
+
+# %%
+# Decoder transform pipelines
+# ---------------------------
+# So far, we've only provided a single transform to the ``transforms`` parameter of
+# :class:`~torchcodec.decoders.VideoDecoder`. But it
+# actually accepts a list of transforms, which become a pipeline of transforms.
+# The order of the list matters: the first transform in the list will receive
+# the originally decoded frame. The output of that transform becomes the input
+# to the next transform in the list, and so on.
+#
+# From now on, we'll use TorchCodec transforms instead of TorchVision
+# transforms. When passed to the :class:`~torchcodec.decoders.VideoDecoder`,
+# they behave identically.
+#
+# A simple example:
+
+from torchcodec.transforms import Resize, CenterCrop
+
+
+crop_resize_decoder = VideoDecoder(
+    penguin_video_path,
+    transforms = [
+        CenterCrop(size=(1280, 1664)),
+        Resize(size=(480, 640)),
+    ]
+)
+crop_resized_during = crop_resize_decoder[5]
+plot(crop_resized_during, title="Center cropped then resized to 480x640")
+
+# %%
+# Performance: memory efficiency and speed
+# ----------------------------------------
+#
+# The main motivation for decoder transforms is *memory efficiency*,
+# particularly when applying transforms that reduce the size of a frame, such
+# as resize and crop. Because the FFmpeg layer knows all of the transforms it
+# needs to apply during decoding, it's able to efficiently reuse memory.
+# Further, full resolution frames are never returned to the Python layer. As a
+# result, there is significantly less total memory needed and less pressure on
+# the Python garbage collector.
+#
+# In `benchmarks `_
+# reducing frames from (1080, 1920) down to (135, 240), we have observed a
+# reduction in peak resident set size from 4.3 GB to 0.4 GB.
+#
+# There is sometimes a runtime benefit, but it is dependent on the number of
+# threads that the :class:`~torchcodec.decoders.VideoDecoder` tells FFmpeg
+# to use.
We define the following benchmark function, as well as the functions +# to benchmark: + + +def bench(f, average_over=3, warmup=1, **f_kwargs): + for _ in range(warmup): + f(**f_kwargs) + + times = [] + for _ in range(average_over): + start_time = perf_counter_ns() + f(**f_kwargs) + end_time = perf_counter_ns() + times.append(end_time - start_time) + + times = torch.tensor(times) * 1e-6 # ns to ms + times_std = times.std().item() + times_med = times.median().item() + return f"{times_med = :.2f}ms +- {times_std:.2f}" + + +from torchcodec import samplers + + +def sample_decoder_transforms(num_threads: int): + decoder = VideoDecoder( + penguin_video_path, + transforms = [ + CenterCrop(size=(1280, 1664)), + Resize(size=(480, 640)), + ], + seek_mode="approximate", + num_ffmpeg_threads=num_threads, + ) + transformed_frames = samplers.clips_at_regular_indices( + decoder, + num_clips=1, + num_frames_per_clip=200 + ) + assert len(transformed_frames.data[0]) == 200 + + +def sample_torchvision_transforms(num_threads: int): + if num_threads > 0: + torch.set_num_threads(num_threads) + decoder = VideoDecoder( + penguin_video_path, + seek_mode="approximate", + num_ffmpeg_threads=num_threads, + ) + frames = samplers.clips_at_regular_indices( + decoder, + num_clips=1, + num_frames_per_clip=200 + ) + transforms = v2.Compose( + [ + v2.CenterCrop(size=(1280, 1664)), + v2.Resize(size=(480, 640)), + ] + ) + transformed_frames = transforms(frames.data) + assert transformed_frames.shape[1] == 200 + +# %% +# When the :class:`~torchcodec.decoders.VideoDecoder` object sets the number of +# FFmpeg threads to 0, that tells FFmpeg to determine how many threads to use +# based on what is available on the current system. 
In such cases, decoder transforms +# will tend to outperform getting back a full frame and applying TorchVision transforms +# sequentially: + + +print(f"decoder transforms: {bench(sample_decoder_transforms, num_threads=0)}") +print(f"torchvision transform: {bench(sample_torchvision_transforms, num_threads=0)}") + +# %% +# The reason is that FFmpeg is applying the decoder transforms in parallel. +# However, if the number of threads is 1 (as is the default), then there is often +# less benefit to using decoder transforms. Using the TorchVision transforms may +# even be faster! + +print(f"decoder transforms: {bench(sample_decoder_transforms, num_threads=1)}") +print(f"torchvision transform: {bench(sample_torchvision_transforms, num_threads=1)}") + +# %% +# In brief, our performance guidance is: +# +# 1. If you are applying a transform pipeline that signficantly reduces +# the dimensions of your input frames and memory efficiency matters, use +# decoder transforms. +# 2. If you are using multiple FFmpeg threads, decoder transforms may be +# faster. Experiment with your setup to verify. +# 3. If you are using a single FFmpeg thread, then decoder transforms may +# be slower. Experiment with your setup to verify. + +shutil.rmtree(temp_dir) +# %% diff --git a/examples/encoding/audio_encoding.py b/examples/encoding/audio_encoding.py index 8bcc1e305..a657512b2 100644 --- a/examples/encoding/audio_encoding.py +++ b/examples/encoding/audio_encoding.py @@ -20,6 +20,7 @@ from IPython.display import Audio as play_audio +# sphinx_gallery_thumbnail_path = '_static/thumbnails/grumps_audio2.jpg' def make_sinewave() -> tuple[torch.Tensor, int]: freq_A = 440 # Hz sample_rate = 16000 # Hz diff --git a/examples/encoding/video_encoding.py b/examples/encoding/video_encoding.py new file mode 100644 index 000000000..4c589127e --- /dev/null +++ b/examples/encoding/video_encoding.py @@ -0,0 +1,307 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +======================================= +Encoding video frames with VideoEncoder +======================================= + +In this example, we'll learn how to encode video frames to a file or to raw +bytes using the :class:`~torchcodec.encoders.VideoEncoder` class. +""" + +# %% +# First, we'll download a video and decode some frames to tensors. +# These will be the input to the :class:`~torchcodec.encoders.VideoEncoder`. For more details on decoding, +# see :ref:`sphx_glr_generated_examples_decoding_basic_example.py`. +# Otherwise, skip ahead to :ref:`creating_encoder`. + +import requests +from torchcodec.decoders import VideoDecoder +from IPython.display import Video + +# sphinx_gallery_thumbnail_path = '_static/thumbnails/not_grumps_encoding_video.jpg' + + +def play_video(encoded_bytes): + return Video( + data=encoded_bytes.numpy().tobytes(), + embed=True, + width=640, + height=360, + mimetype="video/mp4", + ) + + +# Video source: https://www.pexels.com/video/adorable-cats-on-the-lawn-4977395/ +# Author: Altaf Shah. +url = "https://videos.pexels.com/video-files/4977395/4977395-hd_1920_1080_24fps.mp4" + +response = requests.get(url, headers={"User-Agent": ""}) +if response.status_code != 200: + raise RuntimeError(f"Failed to download video. {response.status_code = }.") + +raw_video_bytes = response.content + +decoder = VideoDecoder(raw_video_bytes) +frames = decoder.get_frames_in_range(0, 60).data # Get first 60 frames +frame_rate = decoder.metadata.average_fps + +# %% +# .. _creating_encoder: +# +# Creating an encoder +# ------------------- +# +# Let's instantiate a :class:`~torchcodec.encoders.VideoEncoder`. We will need to provide +# the frames to be encoded as a 4D tensor of shape +# ``(num_frames, num_channels, height, width)`` with values in the ``[0, 255]`` +# range and ``torch.uint8`` dtype. 
We will also need to provide the frame rate of the input +# video. +# +# .. note:: +# +# The ``frame_rate`` parameter corresponds to the frame rate of the +# *input* video. It will also be used for the frame rate of the *output* encoded video. +from torchcodec.encoders import VideoEncoder + +print(f"{frames.shape = }, {frames.dtype = }") +print(f"{frame_rate = } fps") + +encoder = VideoEncoder(frames=frames, frame_rate=frame_rate) + +# %% +# .. _cuda_encoding: +# +# CUDA Encoding +# ------------- +# +# To encode on GPU, pass the frames as a CUDA tensor. This can result in significantly +# faster encoding than CPU. The encoder will automatically select a CUDA-compatible +# codec when frames are on a CUDA device, such as ``h264_nvenc`` or ``hevc_nvenc``. +# +# .. note:: +# +# On GPU, the pixel format is always set to ``nv12`` (which does equivalent chroma subsampling +# to ``yuv420p``). The ``pixel_format`` parameter is not supported for GPU encoding. +# +# .. code-block:: python +# +# gpu_frames = frames.to("cuda") # Move frames to GPU +# gpu_encoder = VideoEncoder(frames=gpu_frames, frame_rate=frame_rate) +# +# That's it! The rest of the encoding process is the same as on CPU. + +# %% +# Encoding to file, bytes, or file-like +# ------------------------------------- +# +# :class:`~torchcodec.encoders.VideoEncoder` supports encoding frames into a +# file via the :meth:`~torchcodec.encoders.VideoEncoder.to_file` method, to +# file-like objects via the :meth:`~torchcodec.encoders.VideoEncoder.to_file_like` +# method, or to raw bytes via :meth:`~torchcodec.encoders.VideoEncoder.to_tensor`. +# For now we will use :meth:`~torchcodec.encoders.VideoEncoder.to_tensor`, so we +# can easily inspect and display the encoded video. 
+ +encoded_frames = encoder.to_tensor(format="mp4") +play_video(encoded_frames) + +# %% +# +# Now that we have encoded data, we can decode it back to verify the +# round-trip encode/decode process works as expected: + +decoder_verify = VideoDecoder(encoded_frames) +decoded_frames = decoder_verify.get_frames_in_range(0, 60).data + +print(f"Re-decoded video: {decoded_frames.shape = }") +print(f"Original frames: {frames.shape = }") + +# %% +# .. _codec_selection: +# +# Codec Selection +# --------------- +# +# By default, the codec used is selected automatically using the file extension provided +# in the ``dest`` parameter for the :meth:`~torchcodec.encoders.VideoEncoder.to_file` method, +# or using the ``format`` parameter for the +# :meth:`~torchcodec.encoders.VideoEncoder.to_file_like` and +# :meth:`~torchcodec.encoders.VideoEncoder.to_tensor` methods. +# +# For example, when encoding to MP4 format, the default codec is typically ``H.264``. +# +# To use a codec other than the default, use the ``codec`` parameter. +# You can specify either a specific codec implementation (e.g., ``"libx264"``) +# or a codec specification (e.g., ``"h264"``). Different codecs offer +# different tradeoffs between quality, file size, and encoding speed. +# +# .. note:: +# +# To see available encoders on your system, run ``ffmpeg -encoders``. 
+# +# Let's encode the same frames using different codecs: + +import tempfile +from pathlib import Path + +# H.264 encoding +h264_output = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name +encoder.to_file(h264_output, codec="libx264") + +# H.265 encoding +hevc_output = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name +encoder.to_file(hevc_output, codec="hevc") + +# Now let's use ffprobe to verify the codec used in the output files +import subprocess + +for output, name in [(h264_output, "h264_output"), (hevc_output, "hevc_output")]: + result = subprocess.run( + [ + "ffprobe", + "-v", + "error", + "-select_streams", + "v:0", + "-show_entries", + "stream=codec_name", + "-of", + "default=noprint_wrappers=1:nokey=1", + output, + ], + capture_output=True, + text=True, + ) + print(f"Codec used in {name}: {result.stdout.strip()}") + + +# %% +# .. _pixel_format: +# +# Pixel Format +# ------------ +# +# The ``pixel_format`` parameter controls the color sampling (chroma subsampling) +# of the output video. This affects both quality and file size. +# +# Common pixel formats: +# +# - ``"yuv420p"`` - 4:2:0 chroma subsampling (standard quality, smaller file size, widely compatible) +# - ``"yuv444p"`` - 4:4:4 chroma subsampling (full chroma resolution, higher quality, larger file size) +# +# Most playback devices and platforms support ``yuv420p``, making it the most +# common choice for video encoding. +# +# .. note:: +# +# Pixel format support depends on the codec used. Use ``ffmpeg -h encoder=`` +# to check available options for your selected codec. + +# Standard pixel format +yuv420_encoded_frames = encoder.to_tensor( + format="mp4", codec="libx264", pixel_format="yuv420p" +) +play_video(yuv420_encoded_frames) + +# %% +# .. _crf: +# +# CRF (Constant Rate Factor) +# -------------------------- +# +# The ``crf`` parameter controls video quality, where lower values produce higher quality output. 
+# +# For example, with the commonly used H.264 codec, ``libx264``: +# +# - Values range from 0 (lossless) to 51 (worst quality) +# - Values 17 or 18 are considered visually lossless, and the default is 23. +# +# .. note:: +# +# The range and interpretation of CRF values depend on the codec used, and +# not all codecs support CRF. Use ``ffmpeg -h encoder=`` to +# check available options for your selected codec. +# + +# High quality (low CRF) +high_quality_output = encoder.to_tensor(format="mp4", codec="libx264", crf=0) +play_video(high_quality_output) + +# %% + +# Low quality (high CRF) +low_quality_output = encoder.to_tensor(format="mp4", codec="libx264", crf=50) +play_video(low_quality_output) + + +# %% +# .. _preset: +# +# Preset +# ------ +# +# The ``preset`` parameter controls the tradeoff between encoding speed and file compression. +# Faster presets encode faster but produce larger files, while slower +# presets take more time to encode but result in better compression. +# +# For example, with the commonly used H.264 codec, ``libx264`` presets include +# ``"ultrafast"`` (fastest), ``"fast"``, ``"medium"`` (default), ``"slow"``, and +# ``"veryslow"`` (slowest, best compression). See the +# `H.264 Video Encoding Guide `_ +# for additional details. +# +# .. note:: +# +# Not all codecs support the ``presets`` option. Use ``ffmpeg -h encoder=`` +# to check available options for your selected codec. +# + +# Fast encoding with a larger file size +fast_output = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name +encoder.to_file(fast_output, codec="libx264", preset="ultrafast") +print(f"Size of fast encoded file: {Path(fast_output).stat().st_size} bytes") + +# Slow encoding for a smaller file size +slow_output = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name +encoder.to_file(slow_output, codec="libx264", preset="veryslow") +print(f"Size of slow encoded file: {Path(slow_output).stat().st_size} bytes") + +# %% +# .. 
_extra_options: +# +# Extra Options +# ------------- +# +# The ``extra_options`` parameter accepts a dictionary of codec-specific options +# that would normally be set via FFmpeg command-line arguments. This enables +# control of encoding settings beyond the common parameters. +# +# For example, some potential extra options for the commonly used H.264 codec, ``libx264`` include: +# +# - ``"g"`` - GOP (Group of Pictures) size / keyframe interval +# - ``"max_b_frames"`` - Maximum number of B-frames between I and P frames +# - ``"tune"`` - Tuning preset (e.g., ``"film"``, ``"animation"``, ``"grain"``) +# +# .. note:: +# +# Use ``ffmpeg -h encoder=`` to see all available options for +# a specific codec. +# + + +custom_output = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name +encoder.to_file( + custom_output, + codec="libx264", + extra_options={ + "g": 50, # Keyframe every 50 frames + "max_b_frames": 0, # Disable B-frames for faster decoding + "tune": "fastdecode", # Optimize for fast decoding + } +) + +# %% diff --git a/mypy.ini b/mypy.ini index bd0ee6ac8..f018ba4f8 100644 --- a/mypy.ini +++ b/mypy.ini @@ -4,3 +4,4 @@ files = src/torchcodec show_error_codes = True pretty = True allow_redefinition = True +follow_untyped_imports = True diff --git a/packaging/install_ffmpeg.sh b/packaging/install_ffmpeg.sh new file mode 100755 index 000000000..32907596a --- /dev/null +++ b/packaging/install_ffmpeg.sh @@ -0,0 +1,36 @@ +#!/usr/bin/env bash +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# This script installs FFmpeg from conda-forge after asserting that FFmpeg is +# not already installed. 
+# +# Usage: +# install_ffmpeg.sh FFMPEG_VERSION +# install_ffmpeg.sh 7.0.1 +# install_ffmpeg.sh 8.0 + +set -euo pipefail + +if [ $# -lt 1 ]; then + echo "Error: Missing required FFmpeg version" + echo "Usage: install_ffmpeg.sh FFMPEG_VERSION" + echo "Example: install_ffmpeg.sh 7.0.1" + exit 1 +fi + +FFMPEG_VERSION="$1" + +# Ideally we would have checked for that before installing the wheel, +# but we need to checkout the repo to access this file, and we don't +# want to checkout the repo before installing the wheel to avoid any +# side-effect. It's OK. +source packaging/helpers.sh +assert_ffmpeg_not_installed + +echo "Installing FFmpeg version $FFMPEG_VERSION from conda-forge..." +conda install "ffmpeg=$FFMPEG_VERSION" -c conda-forge +ffmpeg -version diff --git a/packaging/install_pytorch.sh b/packaging/install_pytorch.sh new file mode 100755 index 000000000..23611e938 --- /dev/null +++ b/packaging/install_pytorch.sh @@ -0,0 +1,39 @@ +#!/usr/bin/env bash +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# This script installs PyTorch and other optional torch packages like +# torchvision from either the nightly or test channel based on the branch: test +# for release branches (and PRs against a release branch), nightly otherwise +# +# Example usage: +# install_pytorch.sh cpu "torch torchvision" +# install_pytorch.sh cpu "torch" +# install_pytorch.sh cu126 "torch torchvision" + +set -euo pipefail + +if [ $# -lt 2 ]; then + echo "Error: Missing required arguments" + echo "Usage: install_pytorch.sh COMPUTE_PLATFORM PACKAGES" + echo "Example: install_pytorch.sh cpu \"torch torchvision\"" + exit 1 +fi + +COMPUTE_PLATFORM="$1" +PACKAGES="$2" + +if [[ (${GITHUB_EVENT_NAME:-} = 'pull_request' && (${GITHUB_BASE_REF:-} = 'release'*)) || (${GITHUB_REF:-} = 'refs/heads/release'*) || (${GITHUB_REF:-} = refs/tags/v*) ]]; then + CHANNEL=test +else + CHANNEL=nightly +fi + +echo "Installing PyTorch packages: $PACKAGES" +echo "Compute platform: $COMPUTE_PLATFORM" +echo "Channel: $CHANNEL" + +python -m pip install --pre $PACKAGES --index-url https://download.pytorch.org/whl/${CHANNEL}/${COMPUTE_PLATFORM} diff --git a/packaging/install_test_dependencies.sh b/packaging/install_test_dependencies.sh new file mode 100755 index 000000000..69c2d6dcb --- /dev/null +++ b/packaging/install_test_dependencies.sh @@ -0,0 +1,19 @@ +#!/usr/bin/env bash +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# This script installs the test dependencies needed to run the test suite. +# +# Example usage: +# install_test_dependencies.sh + +set -euo pipefail + +echo "Installing test dependencies..." +# Ideally we would find a way to get those dependencies from pyproject.toml +python -m pip install numpy pytest pillow + +echo "Test dependencies installed successfully!" 
diff --git a/packaging/install_torchcodec_wheel.sh b/packaging/install_torchcodec_wheel.sh new file mode 100755 index 000000000..77b7b1383 --- /dev/null +++ b/packaging/install_torchcodec_wheel.sh @@ -0,0 +1,31 @@ +#!/usr/bin/env bash +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# This script finds and installs a torchcodec wheel from the dist directory. The +# wheel is expected to have been built and downloaded from a separate job. +# +# Usage: +# install_torchcodec_wheel.sh [WHEEL_PATTERN] +# +# Example usage: +# install_torchcodec_wheel.sh +# install_torchcodec_wheel.sh "*.whl" +# install_torchcodec_wheel.sh "*cu126-cp310*.whl" + +set -euo pipefail + +WHEEL_PATTERN="${1:-*.whl}" + +wheel_path=$(find dist -type f -name "$WHEEL_PATTERN") + +if [ -z "$wheel_path" ]; then + echo "Error: No wheel found matching pattern '$WHEEL_PATTERN' in dist/" + exit 1 +fi + +echo "Installing $wheel_path" +python -m pip install "$wheel_path" -vvv diff --git a/packaging/remove_src.sh b/packaging/remove_src.sh new file mode 100755 index 000000000..e3b2f43b6 --- /dev/null +++ b/packaging/remove_src.sh @@ -0,0 +1,31 @@ +#!/usr/bin/env bash +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# This script removes the src/ directory to ensure tests run against the +# installed wheel rather than local source code. +# +# Usage: +# remove_src.sh + +set -euo pipefail + +echo "Deleting src/ folder to ensure tests use installed wheel..." +# The only reason we checked-out the repo is to get access to the +# tests and to the helper scripts for the CI. We don't care about the rest. 
+# Out of precaution, we delete +# the src/ folder to be extra sure that we're running the code from +# the installed wheel rather than from the source. +# This is just to be extra cautious and very overkill because a) +# there's no way the `torchcodec` package from src/ can be found from +# the PythonPath: the main point of `src/` is precisely to protect +# against that and b) if we ever were to execute code from +# `src/torchcodec`, it would fail loudly because the built .so files +# aren't present there. +rm -r src/ +ls + +echo "src/ folder removed successfully!" diff --git a/pyproject.toml b/pyproject.toml index 6bdcd13f7..367786508 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,7 +2,7 @@ name = "paddlecodec" description = "A video decoder for PyTorch" readme = "README.md" -requires-python = ">=3.8" +requires-python = ">=3.10" license-files = ["LICENSE"] authors = [ { name = "PaddlePaddle Team", email = "Paddle-better@baidu.com" }, @@ -32,7 +32,7 @@ dev = [ first_party_detection = false [tool.black] -target-version = ["py38"] +target-version = ["py310"] [tool.ufmt] @@ -46,6 +46,10 @@ markers = [ "slow: mark test as slow" ] +# Tells pytest not to run tests within this directory by default. +# These tests can still be run by manually specifying the path. +norecursedirs = ["third-party-interface"] + # We don't want to run the slow tests by default. These options are ignored in # the CI, where we definitely want the 'slow' tests to run. 
addopts = "-v -m 'not slow'" diff --git a/setup.py b/setup.py index b8211ea5e..63dd870ff 100644 --- a/setup.py +++ b/setup.py @@ -115,6 +115,9 @@ def _build_all_extensions_with_cmake(self): torchcodec_disable_compile_warning_as_error = os.environ.get( "TORCHCODEC_DISABLE_COMPILE_WARNING_AS_ERROR", "OFF" ) + torchcodec_disable_homebrew_rpath = os.environ.get( + "TORCHCODEC_DISABLE_HOMEBREW_RPATH", "OFF" + ) python_version = sys.version_info cmake_args = [ f"-DCMAKE_INSTALL_PREFIX={self._install_prefix}", @@ -125,6 +128,7 @@ def _build_all_extensions_with_cmake(self): f"-DPYTHON_VERSION={python_version.major}.{python_version.minor}", f"-DENABLE_CUDA={enable_cuda}", f"-DTORCHCODEC_DISABLE_COMPILE_WARNING_AS_ERROR={torchcodec_disable_compile_warning_as_error}", + f"-DTORCHCODEC_DISABLE_HOMEBREW_RPATH={torchcodec_disable_homebrew_rpath}", ] self.build_temp = os.getenv("TORCHCODEC_CMAKE_BUILD_DIR", self.build_temp) @@ -199,14 +203,12 @@ def _write_version_files(): # the content of `version.txt` plus some suffix like "+cpu" or "+cu112". # See # https://github.com/pytorch/test-infra/blob/61e6da7a6557152eb9879e461a26ad667c15f0fd/tools/pkg-helpers/pytorch_pkg_helpers/version.py#L113 - version = version.replace("+cpu", "") with open(_ROOT_DIR / "version.txt", "w") as f: f.write(f"{version}") else: with open(_ROOT_DIR / "version.txt") as f: version = f.readline().strip() try: - version = version.replace("+cpu", "") sha = ( subprocess.check_output( ["git", "rev-parse", "HEAD"], cwd=str(_ROOT_DIR) diff --git a/src/torchcodec/__init__.py b/src/torchcodec/__init__.py index 29131290f..c30bd93d2 100644 --- a/src/torchcodec/__init__.py +++ b/src/torchcodec/__init__.py @@ -4,10 +4,12 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from pathlib import Path + # Note: usort wants to put Frame and FrameBatch after decoders and samplers, # but that results in circular import. 
from ._frame import AudioSamples, Frame, FrameBatch # usort:skip # noqa -from . import decoders, samplers # noqa +from . import decoders, encoders, samplers, transforms # noqa try: # Note that version.py is generated during install. diff --git a/src/torchcodec/_core/AVIOContextHolder.cpp b/src/torchcodec/_core/AVIOContextHolder.cpp index c1188e684..42fccf7c5 100644 --- a/src/torchcodec/_core/AVIOContextHolder.cpp +++ b/src/torchcodec/_core/AVIOContextHolder.cpp @@ -4,8 +4,8 @@ // This source code is licensed under the BSD-style license found in the // LICENSE file in the root directory of this source tree. -#include "src/torchcodec/_core/AVIOContextHolder.h" -#include +#include "AVIOContextHolder.h" +#include "StableABICompat.h" namespace facebook::torchcodec { @@ -16,20 +16,21 @@ void AVIOContextHolder::createAVIOContext( void* heldData, bool isForWriting, int bufferSize) { - TORCH_CHECK( + STD_TORCH_CHECK( bufferSize > 0, "Buffer size must be greater than 0; is " + std::to_string(bufferSize)); auto buffer = static_cast(av_malloc(bufferSize)); - TORCH_CHECK( + STD_TORCH_CHECK( buffer != nullptr, "Failed to allocate buffer of size " + std::to_string(bufferSize)); - TORCH_CHECK(seek != nullptr, "seek method must be defined"); + STD_TORCH_CHECK(seek != nullptr, "seek method must be defined"); if (isForWriting) { - TORCH_CHECK(write != nullptr, "write method must be defined for writing"); + STD_TORCH_CHECK( + write != nullptr, "write method must be defined for writing"); } else { - TORCH_CHECK(read != nullptr, "read method must be defined for reading"); + STD_TORCH_CHECK(read != nullptr, "read method must be defined for reading"); } avioContext_.reset(avioAllocContext( @@ -43,7 +44,7 @@ void AVIOContextHolder::createAVIOContext( if (!avioContext_) { av_freep(&buffer); - TORCH_CHECK(false, "Failed to allocate AVIOContext"); + STD_TORCH_CHECK(false, "Failed to allocate AVIOContext"); } } diff --git a/src/torchcodec/_core/AVIOContextHolder.h 
b/src/torchcodec/_core/AVIOContextHolder.h index 16d70beaf..7b1123e6d 100644 --- a/src/torchcodec/_core/AVIOContextHolder.h +++ b/src/torchcodec/_core/AVIOContextHolder.h @@ -6,7 +6,7 @@ #pragma once -#include "src/torchcodec/_core/FFMPEGCommon.h" +#include "FFMPEGCommon.h" namespace facebook::torchcodec { diff --git a/src/torchcodec/_core/AVIOFileLikeContext.cpp b/src/torchcodec/_core/AVIOFileLikeContext.cpp index 210942b57..1331abd5b 100644 --- a/src/torchcodec/_core/AVIOFileLikeContext.cpp +++ b/src/torchcodec/_core/AVIOFileLikeContext.cpp @@ -4,8 +4,8 @@ // This source code is licensed under the BSD-style license found in the // LICENSE file in the root directory of this source tree. -#include "src/torchcodec/_core/AVIOFileLikeContext.h" -#include +#include "AVIOFileLikeContext.h" +#include "StableABICompat.h" namespace facebook::torchcodec { @@ -20,16 +20,16 @@ AVIOFileLikeContext::AVIOFileLikeContext( py::gil_scoped_acquire gil; if (isForWriting) { - TORCH_CHECK( + STD_TORCH_CHECK( py::hasattr(fileLike, "write"), "File like object must implement a write method for writing."); } else { - TORCH_CHECK( + STD_TORCH_CHECK( py::hasattr(fileLike, "read"), "File like object must implement a read method for reading."); } - TORCH_CHECK( + STD_TORCH_CHECK( py::hasattr(fileLike, "seek"), "File like object must implement a seek method."); } @@ -60,7 +60,7 @@ int AVIOFileLikeContext::read(void* opaque, uint8_t* buf, int buf_size) { break; } - TORCH_CHECK( + STD_TORCH_CHECK( numBytesRead <= request, "Requested up to ", request, diff --git a/src/torchcodec/_core/AVIOFileLikeContext.h b/src/torchcodec/_core/AVIOFileLikeContext.h index fd7f534f3..001cda550 100644 --- a/src/torchcodec/_core/AVIOFileLikeContext.h +++ b/src/torchcodec/_core/AVIOFileLikeContext.h @@ -9,7 +9,7 @@ #include #include -#include "src/torchcodec/_core/AVIOContextHolder.h" +#include "AVIOContextHolder.h" namespace py = pybind11; diff --git a/src/torchcodec/_core/AVIOTensorContext.cpp 
b/src/torchcodec/_core/AVIOTensorContext.cpp index 263ce2228..5b1ac23ce 100644 --- a/src/torchcodec/_core/AVIOTensorContext.cpp +++ b/src/torchcodec/_core/AVIOTensorContext.cpp @@ -4,8 +4,8 @@ // This source code is licensed under the BSD-style license found in the // LICENSE file in the root directory of this source tree. -#include "src/torchcodec/_core/AVIOTensorContext.h" -#include +#include "AVIOTensorContext.h" +#include "StableABICompat.h" namespace facebook::torchcodec { @@ -17,7 +17,7 @@ constexpr int64_t MAX_TENSOR_SIZE = 320'000'000; // 320 MB // The signature of this function is defined by FFMPEG. int read(void* opaque, uint8_t* buf, int buf_size) { auto tensorContext = static_cast(opaque); - TORCH_CHECK( + STD_TORCH_CHECK( tensorContext->current_pos <= tensorContext->data.numel(), "Tried to read outside of the buffer: current_pos=", tensorContext->current_pos, @@ -28,7 +28,7 @@ int read(void* opaque, uint8_t* buf, int buf_size) { static_cast(buf_size), tensorContext->data.numel() - tensorContext->current_pos); - TORCH_CHECK( + STD_TORCH_CHECK( numBytesRead >= 0, "Tried to read negative bytes: numBytesRead=", numBytesRead, @@ -43,7 +43,8 @@ int read(void* opaque, uint8_t* buf, int buf_size) { std::memcpy( buf, - tensorContext->data.data_ptr() + tensorContext->current_pos, + tensorContext->data.const_data_ptr() + + tensorContext->current_pos, numBytesRead); tensorContext->current_pos += numBytesRead; return numBytesRead; @@ -55,7 +56,7 @@ int write(void* opaque, const uint8_t* buf, int buf_size) { int64_t bufSize = static_cast(buf_size); if (tensorContext->current_pos + bufSize > tensorContext->data.numel()) { - TORCH_CHECK( + STD_TORCH_CHECK( tensorContext->data.numel() * 2 <= MAX_TENSOR_SIZE, "We tried to allocate an output encoded tensor larger than ", MAX_TENSOR_SIZE, @@ -64,15 +65,15 @@ int write(void* opaque, const uint8_t* buf, int buf_size) { // We double the size of the outpout tensor. 
Calling cat() may not be the // most efficient, but it's simple. tensorContext->data = - torch::cat({tensorContext->data, tensorContext->data}); + stableCat({tensorContext->data, tensorContext->data}, 0); } - TORCH_CHECK( + STD_TORCH_CHECK( tensorContext->current_pos + bufSize <= tensorContext->data.numel(), "Re-allocation of the output tensor didn't work. ", "This should not happen, please report on TorchCodec bug tracker"); - uint8_t* outputTensorData = tensorContext->data.data_ptr(); + uint8_t* outputTensorData = tensorContext->data.mutable_data_ptr(); std::memcpy(outputTensorData + tensorContext->current_pos, buf, bufSize); tensorContext->current_pos += bufSize; // Track the maximum position written so getOutputTensor's narrow() does not @@ -104,18 +105,18 @@ int64_t seek(void* opaque, int64_t offset, int whence) { } // namespace -AVIOFromTensorContext::AVIOFromTensorContext(torch::Tensor data) +AVIOFromTensorContext::AVIOFromTensorContext(torch::stable::Tensor data) : tensorContext_{data, 0, 0} { - TORCH_CHECK(data.numel() > 0, "data must not be empty"); - TORCH_CHECK(data.is_contiguous(), "data must be contiguous"); - TORCH_CHECK(data.scalar_type() == torch::kUInt8, "data must be kUInt8"); + STD_TORCH_CHECK(data.numel() > 0, "data must not be empty"); + STD_TORCH_CHECK(data.is_contiguous(), "data must be contiguous"); + STD_TORCH_CHECK(data.scalar_type() == kStableUInt8, "data must be kUInt8"); createAVIOContext( &read, nullptr, &seek, &tensorContext_, /*isForWriting=*/false); } AVIOToTensorContext::AVIOToTensorContext() : tensorContext_{ - torch::empty({INITIAL_TENSOR_SIZE}, {torch::kUInt8}), + torch::stable::empty({INITIAL_TENSOR_SIZE}, kStableUInt8), 0, 0} { createAVIOContext( diff --git a/src/torchcodec/_core/AVIOTensorContext.h b/src/torchcodec/_core/AVIOTensorContext.h index bcd97052b..0a50856c6 100644 --- a/src/torchcodec/_core/AVIOTensorContext.h +++ b/src/torchcodec/_core/AVIOTensorContext.h @@ -6,15 +6,15 @@ #pragma once -#include -#include 
"src/torchcodec/_core/AVIOContextHolder.h" +#include "AVIOContextHolder.h" +#include "StableABICompat.h" namespace facebook::torchcodec { namespace detail { struct TensorContext { - torch::Tensor data; + torch::stable::Tensor data; int64_t current_pos; int64_t max_pos; }; @@ -23,19 +23,19 @@ struct TensorContext { // For Decoding: enables users to pass in the entire video or audio as bytes. // Our read and seek functions then traverse the bytes in memory. -class AVIOFromTensorContext : public AVIOContextHolder { +class FORCE_PUBLIC_VISIBILITY AVIOFromTensorContext : public AVIOContextHolder { public: - explicit AVIOFromTensorContext(torch::Tensor data); + explicit AVIOFromTensorContext(torch::stable::Tensor data); private: detail::TensorContext tensorContext_; }; // For Encoding: used to encode into an output uint8 (bytes) tensor. -class AVIOToTensorContext : public AVIOContextHolder { +class FORCE_PUBLIC_VISIBILITY AVIOToTensorContext : public AVIOContextHolder { public: explicit AVIOToTensorContext(); - torch::Tensor getOutputTensor(); + torch::stable::Tensor getOutputTensor(); private: detail::TensorContext tensorContext_; diff --git a/src/torchcodec/_core/BetaCudaDeviceInterface.cpp b/src/torchcodec/_core/BetaCudaDeviceInterface.cpp index 587456f34..804c3ba78 100644 --- a/src/torchcodec/_core/BetaCudaDeviceInterface.cpp +++ b/src/torchcodec/_core/BetaCudaDeviceInterface.cpp @@ -4,20 +4,20 @@ // This source code is licensed under the BSD-style license found in the // LICENSE file in the root directory of this source tree. 
-#include -#include +#include #include #include +#include "StableABICompat.h" -#include "src/torchcodec/_core/BetaCudaDeviceInterface.h" +#include "BetaCudaDeviceInterface.h" -#include "src/torchcodec/_core/DeviceInterface.h" -#include "src/torchcodec/_core/FFMPEGCommon.h" -#include "src/torchcodec/_core/NVDECCache.h" +#include "DeviceInterface.h" +#include "FFMPEGCommon.h" +#include "NVDECCache.h" -#include "src/torchcodec/_core/NVCUVIDRuntimeLoader.h" -#include "src/torchcodec/_core/nvcuvid_include/cuviddec.h" -#include "src/torchcodec/_core/nvcuvid_include/nvcuvid.h" +#include "NVCUVIDRuntimeLoader.h" +#include "nvcuvid_include/cuviddec.h" +#include "nvcuvid_include/nvcuvid.h" extern "C" { #include @@ -28,9 +28,49 @@ namespace facebook::torchcodec { namespace { +// Per-device cache for cuvidGetDecoderCaps results. +// The key is a tuple of (device index, codec type, chroma format, bit depth +// minus 8). +struct DecoderCapsCache { + using Key = + std::tuple; + std::map cache; + std::mutex mutex; + + std::pair getDecoderCaps( + int deviceIndex, + cudaVideoCodec codecType, + cudaVideoChromaFormat chromaFormat, + unsigned int bitDepthMinus8) { + Key key{deviceIndex, codecType, chromaFormat, bitDepthMinus8}; + + std::lock_guard lock(mutex); + auto it = cache.find(key); + if (it != cache.end()) { + return {CUDA_SUCCESS, it->second}; + } + + CUVIDDECODECAPS caps = {}; + caps.eCodecType = codecType; + caps.eChromaFormat = chromaFormat; + caps.nBitDepthMinus8 = bitDepthMinus8; + + CUresult result = cuvidGetDecoderCaps(&caps); + if (result == CUDA_SUCCESS) { + cache[key] = caps; + } + return {result, caps}; + } +}; + +static DecoderCapsCache& getDecoderCapsCache() { + static DecoderCapsCache cache; + return cache; +} + static bool g_cuda_beta = registerDeviceInterface( - DeviceInterfaceKey(torch::kCUDA, /*variant=*/"beta"), - [](const torch::Device& device) { + DeviceInterfaceKey(kStableCUDA, /*variant=*/"beta"), + [](const StableDevice& device) { return new 
BetaCudaDeviceInterface(device); }); @@ -90,7 +130,7 @@ static UniqueCUvideodecoder createDecoder(CUVIDEOFORMAT* videoFormat) { CUvideodecoder* decoder = new CUvideodecoder(); CUresult result = cuvidCreateDecoder(decoder, &decoderParams); - TORCH_CHECK( + STD_TORCH_CHECK( result == CUDA_SUCCESS, "Failed to create NVDEC decoder: ", result); return UniqueCUvideodecoder(decoder, CUvideoDecoderDeleter{}); } @@ -99,7 +139,7 @@ std::optional validateChromaSupport( const AVPixFmtDescriptor* desc) { // Return the corresponding cudaVideoChromaFormat if supported, std::nullopt // otherwise. - TORCH_CHECK(desc != nullptr, "desc can't be null"); + STD_TORCH_CHECK(desc != nullptr, "desc can't be null"); if (desc->nb_components == 1) { return cudaVideoChromaFormat_Monochrome; @@ -152,7 +192,9 @@ std::optional validateCodecSupport(AVCodecID codecId) { } } -bool nativeNVDECSupport(const SharedAVCodecContext& codecContext) { +bool nativeNVDECSupport( + const StableDevice& device, + const SharedAVCodecContext& codecContext) { // Return true iff the input video stream is supported by our NVDEC // implementation. 
@@ -171,12 +213,12 @@ bool nativeNVDECSupport(const SharedAVCodecContext& codecContext) { return false; } - auto caps = CUVIDDECODECAPS{}; - caps.eCodecType = codecType.value(); - caps.eChromaFormat = chromaFormat.value(); - caps.nBitDepthMinus8 = desc->comp[0].depth - 8; - - CUresult result = cuvidGetDecoderCaps(&caps); + auto bitDepthMinus8 = static_cast(desc->comp[0].depth - 8); + auto [result, caps] = getDecoderCapsCache().getDecoderCaps( + getDeviceIndex(device), + codecType.value(), + chromaFormat.value(), + bitDepthMinus8); if (result != CUDA_SUCCESS) { return false; } @@ -221,11 +263,11 @@ void cudaBufferFreeCallback(void* opaque, [[maybe_unused]] uint8_t* data) { } // namespace -BetaCudaDeviceInterface::BetaCudaDeviceInterface(const torch::Device& device) +BetaCudaDeviceInterface::BetaCudaDeviceInterface(const StableDevice& device) : DeviceInterface(device) { - TORCH_CHECK(g_cuda_beta, "BetaCudaDeviceInterface was not registered!"); - TORCH_CHECK( - device_.type() == torch::kCUDA, "Unsupported device: ", device_.str()); + STD_TORCH_CHECK(g_cuda_beta, "BetaCudaDeviceInterface was not registered!"); + STD_TORCH_CHECK( + device_.type() == kStableCUDA, "Unsupported device: must be CUDA"); initializeCudaContextWithPytorch(device_); nppCtx_ = getNppStreamContext(device_); @@ -258,9 +300,11 @@ void BetaCudaDeviceInterface::initialize( const AVStream* avStream, const UniqueDecodingAVFormatContext& avFormatCtx, [[maybe_unused]] const SharedAVCodecContext& codecContext) { - if (!nvcuvidAvailable_ || !nativeNVDECSupport(codecContext)) { - cpuFallback_ = createDeviceInterface(torch::kCPU); - TORCH_CHECK( + STD_TORCH_CHECK(avStream != nullptr, "AVStream cannot be null"); + rotation_ = rotationFromDegrees(getRotationFromStream(avStream)); + if (!nvcuvidAvailable_ || !nativeNVDECSupport(device_, codecContext)) { + cpuFallback_ = createDeviceInterface(kStableCPU); + STD_TORCH_CHECK( cpuFallback_ != nullptr, "Failed to create CPU device interface"); 
cpuFallback_->initialize(avStream, avFormatCtx, codecContext); cpuFallback_->initializeVideo( @@ -271,19 +315,18 @@ void BetaCudaDeviceInterface::initialize( return; } - TORCH_CHECK(avStream != nullptr, "AVStream cannot be null"); timeBase_ = avStream->time_base; frameRateAvgFromFFmpeg_ = avStream->r_frame_rate; const AVCodecParameters* codecPar = avStream->codecpar; - TORCH_CHECK(codecPar != nullptr, "CodecParameters cannot be null"); + STD_TORCH_CHECK(codecPar != nullptr, "CodecParameters cannot be null"); initializeBSF(codecPar, avFormatCtx); // Create parser. Default values that aren't obvious are taken from DALI. CUVIDPARSERPARAMS parserParams = {}; auto codecType = validateCodecSupport(codecPar->codec_id); - TORCH_CHECK( + STD_TORCH_CHECK( codecType.has_value(), "This should never happen, we should be using the CPU fallback by now. Please report a bug."); parserParams.CodecType = codecType.value(); @@ -297,7 +340,7 @@ void BetaCudaDeviceInterface::initialize( parserParams.pfnDisplayPicture = pfnDisplayPictureCallback; CUresult result = cuvidCreateVideoParser(&videoParser_, &parserParams); - TORCH_CHECK( + STD_TORCH_CHECK( result == CUDA_SUCCESS, "Failed to create video parser: ", result); } @@ -308,9 +351,9 @@ void BetaCudaDeviceInterface::initializeBSF( // https://ffmpeg.org/doxygen/7.0/group__lavc__bsf.html // This is only needed for some formats, like H264 or HEVC. 
- TORCH_CHECK(codecPar != nullptr, "codecPar cannot be null"); - TORCH_CHECK(avFormatCtx != nullptr, "AVFormatContext cannot be null"); - TORCH_CHECK( + STD_TORCH_CHECK(codecPar != nullptr, "codecPar cannot be null"); + STD_TORCH_CHECK(avFormatCtx != nullptr, "AVFormatContext cannot be null"); + STD_TORCH_CHECK( avFormatCtx->iformat != nullptr, "AVFormatContext->iformat cannot be null"); std::string filterName; @@ -362,12 +405,12 @@ void BetaCudaDeviceInterface::initializeBSF( } const AVBitStreamFilter* avBSF = av_bsf_get_by_name(filterName.c_str()); - TORCH_CHECK( + STD_TORCH_CHECK( avBSF != nullptr, "Failed to find bitstream filter: ", filterName); AVBSFContext* avBSFContext = nullptr; int retVal = av_bsf_alloc(avBSF, &avBSFContext); - TORCH_CHECK( + STD_TORCH_CHECK( retVal >= AVSUCCESS, "Failed to allocate bitstream filter: ", getFFMPEGErrorStringFromErrorCode(retVal)); @@ -375,13 +418,13 @@ void BetaCudaDeviceInterface::initializeBSF( bitstreamFilter_.reset(avBSFContext); retVal = avcodec_parameters_copy(bitstreamFilter_->par_in, codecPar); - TORCH_CHECK( + STD_TORCH_CHECK( retVal >= AVSUCCESS, "Failed to copy codec parameters: ", getFFMPEGErrorStringFromErrorCode(retVal)); retVal = av_bsf_init(bitstreamFilter_.get()); - TORCH_CHECK( + STD_TORCH_CHECK( retVal == AVSUCCESS, "Failed to initialize bitstream filter: ", getFFMPEGErrorStringFromErrorCode(retVal)); @@ -395,7 +438,7 @@ void BetaCudaDeviceInterface::initializeBSF( // we should handle the case of multiple calls. Probably need to flush buffers, // etc. 
int BetaCudaDeviceInterface::streamPropertyChange(CUVIDEOFORMAT* videoFormat) { - TORCH_CHECK(videoFormat != nullptr, "Invalid video format"); + STD_TORCH_CHECK(videoFormat != nullptr, "Invalid video format"); videoFormat_ = *videoFormat; @@ -414,7 +457,7 @@ int BetaCudaDeviceInterface::streamPropertyChange(CUVIDEOFORMAT* videoFormat) { decoder_ = createDecoder(videoFormat); } - TORCH_CHECK(decoder_, "Failed to get or create decoder"); + STD_TORCH_CHECK(decoder_, "Failed to get or create decoder"); } // DALI also returns min_num_decode_surfaces from this function. This @@ -430,7 +473,7 @@ int BetaCudaDeviceInterface::sendPacket(ReferenceAVPacket& packet) { return cpuFallback_->sendPacket(packet); } - TORCH_CHECK( + STD_TORCH_CHECK( packet.get() && packet->data && packet->size > 0, "sendPacket received an empty packet, this is unexpected, please report."); @@ -478,7 +521,7 @@ ReferenceAVPacket& BetaCudaDeviceInterface::applyBSF( } int retVal = av_bsf_send_packet(bitstreamFilter_.get(), packet.get()); - TORCH_CHECK( + STD_TORCH_CHECK( retVal >= AVSUCCESS, "Failed to send packet to bitstream filter: ", getFFMPEGErrorStringFromErrorCode(retVal)); @@ -488,7 +531,7 @@ ReferenceAVPacket& BetaCudaDeviceInterface::applyBSF( // more than once. We should figure out whether that applies to the BSF we're // using. retVal = av_bsf_receive_packet(bitstreamFilter_.get(), filteredPacket.get()); - TORCH_CHECK( + STD_TORCH_CHECK( retVal >= AVSUCCESS, "Failed to receive packet from bitstream filter: ", getFFMPEGErrorStringFromErrorCode(retVal)); @@ -499,11 +542,19 @@ ReferenceAVPacket& BetaCudaDeviceInterface::applyBSF( // Parser triggers this callback within cuvidParseVideoData when a frame is // ready to be decoded, i.e. the parser received all the necessary packets for a // given frame. It means we can send that frame to be decoded by the hardware -// NVDEC decoder by calling cuvidDecodePicture which is non-blocking. +// NVDEC decoder by calling cuvidDecodePicture. 
int BetaCudaDeviceInterface::frameReadyForDecoding(CUVIDPICPARAMS* picParams) { - TORCH_CHECK(picParams != nullptr, "Invalid picture parameters"); - TORCH_CHECK(decoder_, "Decoder not initialized before picture decode"); - // Send frame to be decoded by NVDEC - non-blocking call. + STD_TORCH_CHECK(picParams != nullptr, "Invalid picture parameters"); + STD_TORCH_CHECK(decoder_, "Decoder not initialized before picture decode"); + // Send frame to be decoded by NVDEC. This may or may not block, depending on + // the internal state of the NVDEC. Presumably, when it blocks, it gets + // automatically unblocked once a frame has been decoded, although how and + // when it happens is unclear. The docs say: + // > cuvidDecodePicture() will stall if wait queue on NVDEC inside driver is + // full. + // and cuviddec.h says: + // > cuvidDecodePicture may block the calling thread if there are too many + // pictures pending in the decode queue. CUresult result = cuvidDecodePicture(*decoder_.get(), picParams); // Yes, you're reading that right, 0 means error, 1 means success @@ -539,8 +590,8 @@ int BetaCudaDeviceInterface::receiveFrame(UniqueAVFrame& avFrame) { // the NPP stream before any color conversion. 
// Re types: we get a cudaStream_t from PyTorch but it's interchangeable with // CUstream - procParams.output_stream = reinterpret_cast( - at::cuda::getCurrentCUDAStream(device_.index()).stream()); + procParams.output_stream = + reinterpret_cast(getCurrentCudaStream(device_.index())); CUdeviceptr framePtr = 0; unsigned int pitch = 0; @@ -579,7 +630,7 @@ void BetaCudaDeviceInterface::unmapPreviousFrame() { } CUresult result = cuvidUnmapVideoFrame(*decoder_.get(), previouslyMappedFrame_); - TORCH_CHECK( + STD_TORCH_CHECK( result == CUDA_SUCCESS, "Failed to unmap previous frame: ", result); previouslyMappedFrame_ = 0; } @@ -588,19 +639,19 @@ UniqueAVFrame BetaCudaDeviceInterface::convertCudaFrameToAVFrame( CUdeviceptr framePtr, unsigned int pitch, const CUVIDPARSERDISPINFO& dispInfo) { - TORCH_CHECK(framePtr != 0, "Invalid CUDA frame pointer"); + STD_TORCH_CHECK(framePtr != 0, "Invalid CUDA frame pointer"); // Get frame dimensions from video format display area (not coded dimensions) // This matches DALI's approach and avoids padding issues int width = videoFormat_.display_area.right - videoFormat_.display_area.left; int height = videoFormat_.display_area.bottom - videoFormat_.display_area.top; - TORCH_CHECK(width > 0 && height > 0, "Invalid frame dimensions"); - TORCH_CHECK( + STD_TORCH_CHECK(width > 0 && height > 0, "Invalid frame dimensions"); + STD_TORCH_CHECK( pitch >= static_cast(width), "Pitch must be >= width"); UniqueAVFrame avFrame(av_frame_alloc()); - TORCH_CHECK(avFrame.get() != nullptr, "Failed to allocate AVFrame"); + STD_TORCH_CHECK(avFrame.get() != nullptr, "Failed to allocate AVFrame"); avFrame->width = width; avFrame->height = height; @@ -631,6 +682,12 @@ UniqueAVFrame BetaCudaDeviceInterface::convertCudaFrameToAVFrame( case 6: avFrame->colorspace = AVCOL_SPC_SMPTE170M; // BT.601 break; + case 9: + avFrame->colorspace = AVCOL_SPC_BT2020_NCL; + break; + case 10: + avFrame->colorspace = AVCOL_SPC_BT2020_CL; + break; default: // Default to BT.601 
avFrame->colorspace = AVCOL_SPC_SMPTE170M; @@ -685,36 +742,36 @@ UniqueAVFrame BetaCudaDeviceInterface::transferCpuFrameToGpuNV12( // - Then we allocate GPU memory and copy the NV12 CPU frame to the GPU. This // is what we return - TORCH_CHECK(cpuFrame != nullptr, "CPU frame cannot be null"); + STD_TORCH_CHECK(cpuFrame != nullptr, "CPU frame cannot be null"); int width = cpuFrame->width; int height = cpuFrame->height; // intermediate NV12 CPU frame. It's not on the GPU yet. UniqueAVFrame nv12CpuFrame(av_frame_alloc()); - TORCH_CHECK(nv12CpuFrame != nullptr, "Failed to allocate NV12 CPU frame"); + STD_TORCH_CHECK(nv12CpuFrame != nullptr, "Failed to allocate NV12 CPU frame"); nv12CpuFrame->format = AV_PIX_FMT_NV12; nv12CpuFrame->width = width; nv12CpuFrame->height = height; int ret = av_frame_get_buffer(nv12CpuFrame.get(), 0); - TORCH_CHECK( + STD_TORCH_CHECK( ret >= 0, "Failed to allocate NV12 CPU frame buffer: ", getFFMPEGErrorStringFromErrorCode(ret)); - SwsFrameContext swsFrameContext( + SwsConfig swsConfig( width, height, static_cast(cpuFrame->format), + cpuFrame->colorspace, width, height); - if (!swsContext_ || prevSwsFrameContext_ != swsFrameContext) { - swsContext_ = createSwsContext( - swsFrameContext, cpuFrame->colorspace, AV_PIX_FMT_NV12, SWS_BILINEAR); - prevSwsFrameContext_ = swsFrameContext; + if (!swsContext_ || prevSwsConfig_ != swsConfig) { + swsContext_ = createSwsContext(swsConfig, AV_PIX_FMT_NV12, SWS_BILINEAR); + prevSwsConfig_ = swsConfig; } int convertedHeight = sws_scale( @@ -725,11 +782,11 @@ UniqueAVFrame BetaCudaDeviceInterface::transferCpuFrameToGpuNV12( height, nv12CpuFrame->data, nv12CpuFrame->linesize); - TORCH_CHECK( + STD_TORCH_CHECK( convertedHeight == height, "sws_scale failed for CPU->NV12 conversion"); int ySize = width * height; - TORCH_CHECK( + STD_TORCH_CHECK( ySize % 2 == 0, "Y plane size must be even. 
Please report on TorchCodec repo."); int uvSize = ySize / 2; // NV12: UV plane is half the size of Y plane @@ -738,13 +795,13 @@ UniqueAVFrame BetaCudaDeviceInterface::transferCpuFrameToGpuNV12( uint8_t* cudaBuffer = nullptr; cudaError_t err = cudaMalloc(reinterpret_cast(&cudaBuffer), totalSize); - TORCH_CHECK( + STD_TORCH_CHECK( err == cudaSuccess, "Failed to allocate CUDA memory: ", cudaGetErrorString(err)); UniqueAVFrame gpuFrame(av_frame_alloc()); - TORCH_CHECK(gpuFrame != nullptr, "Failed to allocate GPU AVFrame"); + STD_TORCH_CHECK(gpuFrame != nullptr, "Failed to allocate GPU AVFrame"); gpuFrame->format = AV_PIX_FMT_CUDA; gpuFrame->width = width; @@ -765,12 +822,12 @@ UniqueAVFrame BetaCudaDeviceInterface::transferCpuFrameToGpuNV12( width, height, cudaMemcpyHostToDevice); - TORCH_CHECK( + STD_TORCH_CHECK( err == cudaSuccess, "Failed to copy Y plane to GPU: ", cudaGetErrorString(err)); - TORCH_CHECK( + STD_TORCH_CHECK( height % 2 == 0, "height must be even. Please report on TorchCodec repo."); err = cudaMemcpy2D( @@ -781,13 +838,13 @@ UniqueAVFrame BetaCudaDeviceInterface::transferCpuFrameToGpuNV12( width, height / 2, cudaMemcpyHostToDevice); - TORCH_CHECK( + STD_TORCH_CHECK( err == cudaSuccess, "Failed to copy UV plane to GPU: ", cudaGetErrorString(err)); ret = av_frame_copy_props(gpuFrame.get(), cpuFrame.get()); - TORCH_CHECK( + STD_TORCH_CHECK( ret >= 0, "Failed to copy frame properties: ", getFFMPEGErrorStringFromErrorCode(ret)); @@ -804,7 +861,7 @@ UniqueAVFrame BetaCudaDeviceInterface::transferCpuFrameToGpuNV12( cudaBufferFreeCallback, // callback triggered by av_frame_free() cudaBuffer, // parameter to callback 0); // flags - TORCH_CHECK( + STD_TORCH_CHECK( gpuFrame->opaque_ref != nullptr, "Failed to create GPU memory cleanup reference"); @@ -814,23 +871,65 @@ UniqueAVFrame BetaCudaDeviceInterface::transferCpuFrameToGpuNV12( void BetaCudaDeviceInterface::convertAVFrameToFrameOutput( UniqueAVFrame& avFrame, FrameOutput& frameOutput, - std::optional 
preAllocatedOutputTensor) { + std::optional preAllocatedOutputTensor) { UniqueAVFrame gpuFrame = cpuFallback_ ? transferCpuFrameToGpuNV12(avFrame) : std::move(avFrame); // TODONVDEC P2: we may need to handle 10bit videos the same way the CUDA // ffmpeg interface does it with maybeConvertAVFrameToNV12OrRGB24(). - TORCH_CHECK( + STD_TORCH_CHECK( gpuFrame->format == AV_PIX_FMT_CUDA, "Expected CUDA format frame from BETA CUDA interface"); - validatePreAllocatedTensorShape(preAllocatedOutputTensor, gpuFrame); + cudaStream_t nvdecStream = getCurrentCudaStream(device_.index()); - at::cuda::CUDAStream nvdecStream = - at::cuda::getCurrentCUDAStream(device_.index()); + if (rotation_ == Rotation::NONE) { + validatePreAllocatedTensorShape(preAllocatedOutputTensor, gpuFrame); + frameOutput.data = convertNV12FrameToRGB( + gpuFrame, device_, nppCtx_, nvdecStream, preAllocatedOutputTensor); + } else { + // preAllocatedOutputTensor has post-rotation dimensions, but NV12->RGB + // conversion outputs pre-rotation dimensions, so we can't use it as the + // conversion destination or validate it against the frame shape. + // Once we support native transforms on the beta CUDA interface, rotation + // should be handled as part of the transform pipeline instead. 
+ frameOutput.data = convertNV12FrameToRGB( + gpuFrame, + device_, + nppCtx_, + nvdecStream, + /*preAllocatedOutputTensor=*/std::nullopt); + applyRotation(frameOutput, preAllocatedOutputTensor); + } +} - frameOutput.data = convertNV12FrameToRGB( - gpuFrame, device_, nppCtx_, nvdecStream, preAllocatedOutputTensor); +void BetaCudaDeviceInterface::applyRotation( + FrameOutput& frameOutput, + std::optional preAllocatedOutputTensor) { + int k = 0; + switch (rotation_) { + case Rotation::CCW90: + k = 1; + break; + case Rotation::ROTATE180: + k = 2; + break; + case Rotation::CW90: + k = 3; + break; + default: + STD_TORCH_CHECK(false, "Unexpected rotation value"); + break; + } + // Apply rotation using rot90 on the H and W dims of our HWC tensor. + // stableRot90 returns a view, so we need to make it contiguous. + frameOutput.data = + torch::stable::contiguous(stableRot90(frameOutput.data, k, 0, 1)); + + if (preAllocatedOutputTensor.has_value()) { + torch::stable::copy_(preAllocatedOutputTensor.value(), frameOutput.data); + frameOutput.data = preAllocatedOutputTensor.value(); + } } std::string BetaCudaDeviceInterface::getDetails() { diff --git a/src/torchcodec/_core/BetaCudaDeviceInterface.h b/src/torchcodec/_core/BetaCudaDeviceInterface.h index 0b0e7e6c6..8f44dbda1 100644 --- a/src/torchcodec/_core/BetaCudaDeviceInterface.h +++ b/src/torchcodec/_core/BetaCudaDeviceInterface.h @@ -15,27 +15,27 @@ #pragma once -#include "src/torchcodec/_core/CUDACommon.h" -#include "src/torchcodec/_core/Cache.h" -#include "src/torchcodec/_core/DeviceInterface.h" -#include "src/torchcodec/_core/FFMPEGCommon.h" -#include "src/torchcodec/_core/NVDECCache.h" +#include "CUDACommon.h" +#include "Cache.h" +#include "DeviceInterface.h" +#include "FFMPEGCommon.h" +#include "NVDECCache.h" +#include "Transform.h" -#include #include #include #include #include #include -#include "src/torchcodec/_core/nvcuvid_include/cuviddec.h" -#include "src/torchcodec/_core/nvcuvid_include/nvcuvid.h" +#include 
"nvcuvid_include/cuviddec.h" +#include "nvcuvid_include/nvcuvid.h" namespace facebook::torchcodec { class BetaCudaDeviceInterface : public DeviceInterface { public: - explicit BetaCudaDeviceInterface(const torch::Device& device); + explicit BetaCudaDeviceInterface(const StableDevice& device); virtual ~BetaCudaDeviceInterface(); void initialize( @@ -46,8 +46,7 @@ class BetaCudaDeviceInterface : public DeviceInterface { void convertAVFrameToFrameOutput( UniqueAVFrame& avFrame, FrameOutput& frameOutput, - std::optional preAllocatedOutputTensor = - std::nullopt) override; + std::optional preAllocatedOutputTensor) override; int sendPacket(ReferenceAVPacket& packet) override; int sendEOFPacket() override; @@ -83,6 +82,10 @@ class BetaCudaDeviceInterface : public DeviceInterface { UniqueAVFrame transferCpuFrameToGpuNV12(UniqueAVFrame& cpuFrame); + void applyRotation( + FrameOutput& frameOutput, + std::optional preAllocatedOutputTensor); + CUvideoparser videoParser_ = nullptr; UniqueCUvideodecoder decoder_; CUVIDEOFORMAT videoFormat_ = {}; @@ -102,7 +105,9 @@ class BetaCudaDeviceInterface : public DeviceInterface { std::unique_ptr cpuFallback_; bool nvcuvidAvailable_ = false; UniqueSwsContext swsContext_; - SwsFrameContext prevSwsFrameContext_; + + SwsConfig prevSwsConfig_; + Rotation rotation_ = Rotation::NONE; }; } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/CMakeLists.txt b/src/torchcodec/_core/CMakeLists.txt index f6a02596a..c040e2432 100644 --- a/src/torchcodec/_core/CMakeLists.txt +++ b/src/torchcodec/_core/CMakeLists.txt @@ -1,6 +1,6 @@ cmake_minimum_required(VERSION 3.18) project(TorchCodec) -set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD 20) set(CMAKE_CXX_STANDARD_REQUIRED ON) set(PYBIND11_FINDPYTHON ON) @@ -44,6 +44,8 @@ endif() if (WIN32) # Avoid warnings about non-ASCII characters in source files. set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /wd4819") + # Required for Unicode support in fmt library (bundled with PyTorch). 
+ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /utf-8") # Important for when we add Windows CUDA: exporting all symbols is limited to # 65535 symbols, which (apparently) will not work for CUDA. # https://github.com/pytorch/pytorch/pull/3650 @@ -64,7 +66,7 @@ function(make_torchcodec_sublibrary library_dependencies) add_library(${library_name} ${type} ${sources}) - set_target_properties(${library_name} PROPERTIES CXX_STANDARD 17) + set_target_properties(${library_name} PROPERTIES CXX_STANDARD 20) target_include_directories(${library_name} PRIVATE ./../../../ @@ -74,10 +76,47 @@ function(make_torchcodec_sublibrary "${TORCH_INSTALL_PREFIX}/include/paddle/phi/api/include/compat/torch/csrc/api/include" ${Python3_INCLUDE_DIRS} ) + # Use fmt library in header-only mode (bundled with PyTorch): the fmt + # library is bundled with PyTorch and the headers are available in + # TORCH_INSTALL_PREFIX/include/, which we add above. To be compatible with + # torch, we want to rely on the "header-only" mode of these headers and hit + # all the `#ifdef FMT_HEADER_ONLY` paths, so we define FMT_HEADER_ONLY here.``` + target_compile_definitions(${library_name} PRIVATE FMT_HEADER_ONLY) # Avoid adding the "lib" prefix which we already add explicitly. set_target_properties(${library_name} PROPERTIES PREFIX "") + # On macOS, for wheels, add Homebrew's FFmpeg library path to the rpath so + # that users with Homebrew-installed FFmpeg can use torchcodec without + # setting DYLD_LIBRARY_PATH. See + # https://github.com/pytorch/torchcodec/issues/570. + # That's the behavior we enable by default when building TorchCodec wheels + # which we'll ship on PyPI. + # Note that this means homebrew-FFmpeg will always take precedence over + # conda-FFmpeg. For this reason, this behavior can (and should!) be disabled + # when building TorchCodec for conda, by setting + # TORCHCODEC_DISABLE_HOMEBREW_RPATH. 
+ # We should have the following behavior for users: + # - For TorchCodec installed from PyPI (pip or uv): Users don't have to use + # a conda env, and by default the homebrew-ffmpeg should be found. For + # users who do use a conda env, they should be able to use the + # conda-installed FFmpeg, but only if HomeBrew FFmpeg doesn't exist since + # homebrew will take precedence. + # - For TorchCodec installed from conda, users must install FFmpeg from + # conda as well (it should be a dependency of the TorchCodec conda + # package). Homebrew FFmpeg shouldn't be used. + if(APPLE AND NOT (DEFINED TORCHCODEC_DISABLE_HOMEBREW_RPATH AND TORCHCODEC_DISABLE_HOMEBREW_RPATH)) + if(DEFINED ENV{HOMEBREW_PREFIX}) + set(HOMEBREW_FFMPEG_LIB "$ENV{HOMEBREW_PREFIX}/opt/ffmpeg/lib") + else() + # Default Homebrew location on Apple Silicon + set(HOMEBREW_FFMPEG_LIB "/opt/homebrew/opt/ffmpeg/lib") + endif() + set_target_properties(${library_name} PROPERTIES + INSTALL_RPATH "${HOMEBREW_FFMPEG_LIB}" + ) + endif() + target_link_libraries( ${library_name} PUBLIC @@ -122,6 +161,10 @@ function(make_torchcodec_libraries Encoder.cpp ValidationUtils.cpp Transform.cpp + Metadata.cpp + SwScale.cpp + WavDecoder.cpp + NVDECCacheConfig.cpp ) if(ENABLE_CUDA) @@ -147,6 +190,15 @@ function(make_torchcodec_libraries "${core_library_dependencies}" ) + if(ENABLE_CUDA) + # We define USE_CUDA to guard CUDA-specific code paths (e.g. + # NVDECCache usage in NVDECCacheConfig.cpp) and because some torch + # APIs like aoti_torch_get_current_cuda_stream are only exposed when + # USE_CUDA is defined. + # https://github.com/pytorch/pytorch/blob/98e36864e640023a716e058d894ea2d20e76e5f7/torch/csrc/inductor/aoti_torch/c/shim.h#L573-L602 + target_compile_definitions(${core_library_name} PRIVATE USE_CUDA) + endif() + # 2. Create libtorchcodec_custom_opsN.{ext}. 
set(custom_ops_library_name "libtorchcodec_custom_ops${ffmpeg_major_version}") set(custom_ops_sources @@ -221,8 +273,36 @@ function(make_torchcodec_libraries ) endif() - # The value we use here must match the value we return from - # _get_pybind_ops_module_name() on the Python side. If the values do not + # We disable the "attributes" warning in the core library. + # This warning is triggered when, e.g., a class has "default" (i.e. + # public) visibility but a member of that class has "hidden" visibility. + # We have such a pattern in the core library with the VideoStreamOptions + # class and its StableDevice field: + # - VideoStreamOptions has "default" (public) visibility because it is part + # of the core library, and the core library must have public symbols + # because it is depended on by other libraries (the pybind ops library and + # the custom ops library) + # - the StableDevice field has "hidden" visibility because this is how + # torch::stable exports it. + # This creates this mismatch where a class has "higher" visibility than its + # member, hence the warning. + # + # We choose to silence this warning here. A possibly better solution would + # be to have a more fine-grained visibility control where each + # class/function would be "hidden" unless explicitly marked as public, e.g. + # through a TORCHCODEC_API annotation. In this case, it is likely that + # VideoStreamOptions could be in fact "hidden" as well. + # TODO_STABLE_ABI: do that! + if (LINUX) + target_compile_options( + ${core_library_name} + PRIVATE + "-Wno-attributes" + ) + endif() + + # The value we use here must match _PYBIND_OPS_MODULE_NAME in + # torchcodec/_internally_replaced_utils.py. If the values do not # match, then we will be unable to import the C++ shared library as a # Python module at runtime. target_compile_definitions( @@ -271,16 +351,16 @@ if(DEFINED ENV{BUILD_AGAINST_ALL_FFMPEG_FROM_S3}) you still need a different FFmpeg to be installed for run time!" 
) - # This will expose the ffmpeg4, ffmpeg5, ffmpeg6, ffmpeg7, and ffmpeg8 targets + # This will expose the torchcodec::ffmpeg{N} (N=4,5,6,7,8) targets include( ${CMAKE_CURRENT_SOURCE_DIR}/fetch_and_expose_non_gpl_ffmpeg_libs.cmake ) - make_torchcodec_libraries(8 ffmpeg8) - make_torchcodec_libraries(7 ffmpeg7) - make_torchcodec_libraries(6 ffmpeg6) - make_torchcodec_libraries(4 ffmpeg4) - make_torchcodec_libraries(5 ffmpeg5) + make_torchcodec_libraries(8 torchcodec::ffmpeg8) + make_torchcodec_libraries(7 torchcodec::ffmpeg7) + make_torchcodec_libraries(6 torchcodec::ffmpeg6) + make_torchcodec_libraries(4 torchcodec::ffmpeg4) + make_torchcodec_libraries(5 torchcodec::ffmpeg5) else() message( STATUS @@ -289,40 +369,12 @@ else() installed FFmpeg from conda, make sure pkg-config is installed from conda as well." ) - find_package(PkgConfig REQUIRED) - pkg_check_modules(LIBAV REQUIRED IMPORTED_TARGET - libavdevice - libavfilter - libavformat - libavcodec - libavutil - libswresample - libswscale - ) - # Split libavcodec's version string by '.' and convert it to a list - string(REPLACE "." 
";" libavcodec_version_list ${LIBAV_libavcodec_VERSION}) - # Get the first element of the list, which is the major version - list(GET libavcodec_version_list 0 libavcodec_major_version) - - if (${libavcodec_major_version} STREQUAL "58") - set(ffmpeg_major_version "4") - elseif (${libavcodec_major_version} STREQUAL "59") - set(ffmpeg_major_version "5") - elseif (${libavcodec_major_version} STREQUAL "60") - set(ffmpeg_major_version "6") - elseif (${libavcodec_major_version} STREQUAL "61") - set(ffmpeg_major_version "7") - elseif (${libavcodec_major_version} STREQUAL "62") - set(ffmpeg_major_version "8") - else() - message( - FATAL_ERROR - "Unsupported libavcodec version: ${libavcodec_major_version}" - ) - endif() + # This will expose `add_ffmpeg_target_with_pkg_config` + include("${CMAKE_CURRENT_SOURCE_DIR}/../share/cmake/TorchCodec/ffmpeg_versions.cmake") - make_torchcodec_libraries(${ffmpeg_major_version} PkgConfig::LIBAV) + add_ffmpeg_target_with_pkg_config(ffmpeg_major_version) + make_torchcodec_libraries(${ffmpeg_major_version} torchcodec::ffmpeg${ffmpeg_major_version}) # Expose these values updwards so that the test compilation does not need # to re-figure it out. FIXME: it's not great that we just copy-paste the diff --git a/src/torchcodec/_core/CUDACommon.cpp b/src/torchcodec/_core/CUDACommon.cpp index 4532e3c76..2dd11b171 100644 --- a/src/torchcodec/_core/CUDACommon.cpp +++ b/src/torchcodec/_core/CUDACommon.cpp @@ -4,8 +4,11 @@ // This source code is licensed under the BSD-style license found in the // LICENSE file in the root directory of this source tree. 
-#include "src/torchcodec/_core/CUDACommon.h" -#include "src/torchcodec/_core/Cache.h" // for PerGpuCache +#include "CUDACommon.h" +#include +#include "Cache.h" // for PerGpuCache +#include "StableABICompat.h" +#include "ValidationUtils.h" namespace facebook::torchcodec { @@ -21,12 +24,48 @@ PerGpuCache g_cached_npp_ctxs( } // namespace -void initializeCudaContextWithPytorch(const torch::Device& device) { +cudaStream_t getCurrentCudaStream(int32_t deviceIndex) { + // This is the documented and blessed way to get the current CUDA stream with + // the stable ABI. aoti_torch_get_current_cuda_stream, TORCH_ERROR_CODE_CHECK, + // and the corresponding torch/csrc/inductor/aoti_torch/c/shim.h header are + // all safe to use: + // https://github.com/pytorch/pytorch/blob/7bc8d4b0648e1d364dce0104c3aea2e7e3c1640a/docs/cpp/source/stable.rst?plain=1#L172-L179 + void* stream = nullptr; + TORCH_ERROR_CODE_CHECK( + aoti_torch_get_current_cuda_stream(deviceIndex, &stream)); + // Note: no need for checking against nullptr stream, it's a valid default + // stream value. + return static_cast(stream); +} + +// Make waitingStream wait until all work currently enqueued on runningStream +// has completed. +void syncStreams(cudaStream_t runningStream, cudaStream_t waitingStream) { + cudaEvent_t event; + cudaError_t err = cudaEventCreate(&event); + STD_TORCH_CHECK( + err == cudaSuccess, "cudaEventCreate failed: ", cudaGetErrorString(err)); + + err = cudaEventRecord(event, runningStream); + STD_TORCH_CHECK( + err == cudaSuccess, "cudaEventRecord failed: ", cudaGetErrorString(err)); + + err = cudaStreamWaitEvent(waitingStream, event, 0); + STD_TORCH_CHECK( + err == cudaSuccess, + "cudaStreamWaitEvent failed: ", + cudaGetErrorString(err)); + + cudaEventDestroy(event); +} + +void initializeCudaContextWithPytorch(const StableDevice& device) { // It is important for pytorch itself to create the cuda context. If ffmpeg // creates the context it may not be compatible with pytorch. 
// This is a dummy tensor to initialize the cuda context. - torch::Tensor dummyTensorForCudaInitialization = torch::zeros( - {1}, torch::TensorOptions().dtype(torch::kUInt8).device(device)); + torch::stable::Tensor dummyTensorForCudaInitialization = torch::stable::empty( + {1}, kStableUInt8, std::nullopt, StableDevice(device)); + torch::stable::zero_(dummyTensorForCudaInitialization); } /* clang-format off */ @@ -146,24 +185,111 @@ void initializeCudaContextWithPytorch(const torch::Device& device) { // [ 1.0000e+00, -1.8732e-01, -4.6812e-01, -128] // [ 1.0000e+00, 1.8556e+00, 4.6231e-09 , -128]]) // -// And that's what we need to pass for BT701, full range. +// And that's what we need to pass for BT709, full range. /* clang-format on */ // BT.709 full range color conversion matrix for YUV to RGB conversion. // See Note [YUV -> RGB Color Conversion, color space and color range] +#if CUDART_VERSION >= 13000 const Npp32f bt709FullRangeColorTwist[3][4] = { {1.0f, 0.0f, 1.5748f, 0.0f}, {1.0f, -0.187324273f, -0.468124273f, -128.0f}, {1.0f, 1.8556f, 0.0f, -128.0f}}; +#else +// The note above about nppiNV12ToRGB_8u_ColorTwist32f_P2C3R_Ctx +// offsets is actually only true for CUDA 13+. For CUDA 12, the behavior is +// different and still undocumented. +// See https://github.com/meta-pytorch/torchcodec/issues/1262#issue-3989049538 +// for how these offsets need to be derived. +const Npp32f bt709FullRangeColorTwist[3][4] = { + {1.0f, 0.0f, 1.5748f, -201.5744f}, + {1.0f, -0.187324273f, -0.468124273f, 83.8974f}, + {1.0f, 1.8556f, 0.0f, -237.5168f}}; +#endif -torch::Tensor convertNV12FrameToRGB( +// BT.601 full range color conversion matrix for YUV to RGB conversion. 
+// Luma coefficients for BT.601: kr=0.299, kg=0.587, kb=0.114 +// See Note [YUV -> RGB Color Conversion, color space and color range] +#if CUDART_VERSION >= 13000 +const Npp32f bt601FullRangeColorTwist[3][4] = { + {1.0f, 0.0f, 1.402f, 0.0f}, + {1.0f, -0.344136286f, -0.714136286f, -128.0f}, + {1.0f, 1.772f, 0.0f, -128.0f}}; +#else +const Npp32f bt601FullRangeColorTwist[3][4] = { + {1.0f, 0.0f, 1.402f, -179.456f}, + {1.0f, -0.344136286f, -0.714136286f, 135.4589f}, + {1.0f, 1.772f, 0.0f, -226.816f}}; +#endif + +// BT.601 limited range color conversion matrix for YUV to RGB conversion. +// Same luma coefficients as above, but Y is scaled from [16, 235] to [0, 255] +// (factor 255/219 = 1.16438356) and UV from [16, 240] (factor 255/224). +// NPP provides a pre-defined color conversion function for BT.601 limited +// range: nppiNV12ToRGB_8u_P2C3R_Ctx. But it's not closely matching the +// results we have on CPU. So we're using a custom color conversion matrix, +// which provides more accurate results. +// See Note [YUV -> RGB Color Conversion, color space and color range] +#if CUDART_VERSION >= 13000 +const Npp32f bt601LimitedRangeColorTwist[3][4] = { + {1.16438356f, 0.0f, 1.59602679f, -16.0f}, + {1.16438356f, -0.39176229f, -0.81296765f, -128.0f}, + {1.16438356f, 2.01723214f, 0.0f, -128.0f}}; +#else +const Npp32f bt601LimitedRangeColorTwist[3][4] = { + {1.16438356f, 0.0f, 1.59602679f, -222.9216f}, + {1.16438356f, -0.39176229f, -0.81296765f, 135.5753f}, + {1.16438356f, 2.01723214f, 0.0f, -276.8358f}}; +#endif + +// BT.2020 color conversion matrices for YUV to RGB conversion. +// BT.2020 uses Kr=0.2627, Kb=0.0593, Kg=1-Kr-Kb=0.6780 +// Derived the same way as BT.709 above (see the Note). +// The 3x3 coefficients come from inverting the RGB->YUV matrix: +// R = Y + 0*Cb + 1.4746*Cr +// G = Y - 0.164553*Cb - 0.571353*Cr +// B = Y + 1.8814*Cb + 0*Cr +// +// The 4th column (offset) depends on the CUDA version, because NPP changed +// the ColorTwist convention in CUDA 13. 
See the Note above and PR #1265. +// On CUDA >= 13: NPP internally centers U/V by subtracting 128. +// On CUDA < 13: NPP does NOT center U/V, so the offset must encode the full +// Cb/Cr centering contribution expanded into the constant term. +#if CUDART_VERSION >= 13000 +const Npp32f bt2020FullRangeColorTwist[3][4] = { + {1.0f, 0.0f, 1.4746f, 0.0f}, + {1.0f, -0.164553127f, -0.571353127f, -128.0f}, + {1.0f, 1.8814f, 0.0f, -128.0f}}; + +const Npp32f bt2020LimitedRangeColorTwist[3][4] = { + {1.16438356f, 0.0f, 1.67867411f, -16.0f}, + {1.16438356f, -0.187326105f, -0.650424319f, -128.0f}, + {1.16438356f, 2.14177232f, 0.0f, -128.0f}}; +#else +// CUDA < 13: expand Cb/Cr centering into the offset column. +// Full range offset_R = -(1.4746*128) = -188.7488 +// Full range offset_G = 0.164553127*128 + 0.571353127*128 = 94.196 +// Full range offset_B = -(1.8814*128) = -240.8192 +const Npp32f bt2020FullRangeColorTwist[3][4] = { + {1.0f, 0.0f, 1.4746f, -188.7488f}, + {1.0f, -0.164553127f, -0.571353127f, 94.196f}, + {1.0f, 1.8814f, 0.0f, -240.8192f}}; + +// Limited range: Y offset = -(1.16438356*16), plus Cb/Cr centering. 
+const Npp32f bt2020LimitedRangeColorTwist[3][4] = { + {1.16438356f, 0.0f, 1.67867411f, -233.5004f}, + {1.16438356f, -0.187326105f, -0.650424319f, 88.6019f}, + {1.16438356f, 2.14177232f, 0.0f, -292.7770f}}; +#endif + +torch::stable::Tensor convertNV12FrameToRGB( UniqueAVFrame& avFrame, - const torch::Device& device, + const StableDevice& device, const UniqueNppContext& nppCtx, - at::cuda::CUDAStream nvdecStream, - std::optional preAllocatedOutputTensor) { + cudaStream_t nvdecStream, + std::optional preAllocatedOutputTensor) { auto frameDims = FrameDims(avFrame->height, avFrame->width); - torch::Tensor dst; + torch::stable::Tensor dst; if (preAllocatedOutputTensor.has_value()) { dst = preAllocatedOutputTensor.value(); } else { @@ -173,15 +299,12 @@ torch::Tensor convertNV12FrameToRGB( // We need to make sure NVDEC has finished decoding a frame before // color-converting it with NPP. // So we make the NPP stream wait for NVDEC to finish. - at::cuda::CUDAStream nppStream = - at::cuda::getCurrentCUDAStream(device.index()); - at::cuda::CUDAEvent nvdecDoneEvent; - nvdecDoneEvent.record(nvdecStream); - nvdecDoneEvent.block(nppStream); + cudaStream_t nppStream = getCurrentCudaStream(device.index()); + syncStreams(/*runningStream=*/nvdecStream, /*waitingStream=*/nppStream); - nppCtx->hStream = nppStream.stream(); + nppCtx->hStream = nppStream; cudaError_t err = cudaStreamGetFlags(nppCtx->hStream, &nppCtx->nStreamFlags); - TORCH_CHECK( + STD_TORCH_CHECK( err == cudaSuccess, "cudaStreamGetFlags failed: ", cudaGetErrorString(err)); @@ -206,7 +329,7 @@ torch::Tensor convertNV12FrameToRGB( status = nppiNV12ToRGB_8u_ColorTwist32f_P2C3R_Ctx( yuvData, srcStep, - static_cast(dst.data_ptr()), + dst.mutable_data_ptr(), dst.stride(0), oSizeROI, bt709FullRangeColorTwist, @@ -224,29 +347,67 @@ torch::Tensor convertNV12FrameToRGB( status = nppiNV12ToRGB_709CSC_8u_P2C3R_Ctx( yuvData, avFrame->linesize[0], - static_cast(dst.data_ptr()), + dst.mutable_data_ptr(), dst.stride(0), oSizeROI, 
*nppCtx); } - } else { - // TODO we're assuming BT.601 color space (and probably limited range) by - // calling nppiNV12ToRGB_8u_P2C3R_Ctx. We should handle BT.601 full range, - // and other color-spaces like 2020. - status = nppiNV12ToRGB_8u_P2C3R_Ctx( + } else if ( + avFrame->colorspace == AVColorSpace::AVCOL_SPC_BT2020_NCL || + avFrame->colorspace == AVColorSpace::AVCOL_SPC_BT2020_CL) { + int srcStep[2] = {avFrame->linesize[0], avFrame->linesize[1]}; + + const Npp32f(*matrix)[4] = + (avFrame->color_range == AVColorRange::AVCOL_RANGE_JPEG) + ? bt2020FullRangeColorTwist + : bt2020LimitedRangeColorTwist; + + status = nppiNV12ToRGB_8u_ColorTwist32f_P2C3R_Ctx( yuvData, - avFrame->linesize[0], - static_cast(dst.data_ptr()), + srcStep, + dst.mutable_data_ptr(), dst.stride(0), oSizeROI, + matrix, *nppCtx); + } else { + // When the colorspace is unspecified, we default to BT.601. This matches + // FFmpeg's swscale behavior: sws_getCoefficients(SWS_CS_DEFAULT) returns + // BT.601 coefficients + // https://github.com/FFmpeg/FFmpeg/blob/5b8a4a0e14cde74704b13493eb33cce3be260283/libswscale/swscale.h#L396-L403 + if (avFrame->color_range == AVColorRange::AVCOL_RANGE_JPEG) { + // BT.601 full range via custom color twist + int srcStep[2] = {avFrame->linesize[0], avFrame->linesize[1]}; + status = nppiNV12ToRGB_8u_ColorTwist32f_P2C3R_Ctx( + yuvData, + srcStep, + dst.mutable_data_ptr(), + validateInt64ToInt(dst.stride(0), "dst.stride(0)"), + oSizeROI, + bt601FullRangeColorTwist, + *nppCtx); + } else { + // NPP provides a pre-defined color conversion function for BT.601 + // limited range: nppiNV12ToRGB_8u_P2C3R_Ctx. But it's not closely + // matching the results we have on CPU. So we're using a custom color + // conversion matrix, which provides more accurate results. 
+ int srcStep[2] = {avFrame->linesize[0], avFrame->linesize[1]}; + status = nppiNV12ToRGB_8u_ColorTwist32f_P2C3R_Ctx( + yuvData, + srcStep, + dst.mutable_data_ptr(), + validateInt64ToInt(dst.stride(0), "dst.stride(0)"), + oSizeROI, + bt601LimitedRangeColorTwist, + *nppCtx); + } } - TORCH_CHECK(status == NPP_SUCCESS, "Failed to convert NV12 frame."); + STD_TORCH_CHECK(status == NPP_SUCCESS, "Failed to convert NV12 frame."); return dst; } -UniqueNppContext getNppStreamContext(const torch::Device& device) { +UniqueNppContext getNppStreamContext(const StableDevice& device) { int deviceIndex = getDeviceIndex(device); UniqueNppContext nppCtx = g_cached_npp_ctxs.get(device); @@ -265,7 +426,7 @@ UniqueNppContext getNppStreamContext(const torch::Device& device) { nppCtx = std::make_unique(); cudaDeviceProp prop{}; cudaError_t err = cudaGetDeviceProperties(&prop, deviceIndex); - TORCH_CHECK( + STD_TORCH_CHECK( err == cudaSuccess, "cudaGetDeviceProperties failed: ", cudaGetErrorString(err)); @@ -282,7 +443,7 @@ UniqueNppContext getNppStreamContext(const torch::Device& device) { } void returnNppStreamContextToCache( - const torch::Device& device, + const StableDevice& device, UniqueNppContext nppCtx) { if (nppCtx) { g_cached_npp_ctxs.addIfCacheHasCapacity(device, std::move(nppCtx)); @@ -290,7 +451,7 @@ void returnNppStreamContextToCache( } void validatePreAllocatedTensorShape( - const std::optional& preAllocatedOutputTensor, + const std::optional& preAllocatedOutputTensor, const UniqueAVFrame& avFrame) { // Note that CUDA does not yet support transforms, so the only possible // frame dimensions are the raw decoded frame's dimensions. 
@@ -298,7 +459,7 @@ void validatePreAllocatedTensorShape( if (preAllocatedOutputTensor.has_value()) { auto shape = preAllocatedOutputTensor.value().sizes(); - TORCH_CHECK( + STD_TORCH_CHECK( (shape.size() == 3) && (shape[0] == frameDims.height) && (shape[1] == frameDims.width) && (shape[2] == 3), "Expected tensor of shape ", @@ -306,21 +467,21 @@ void validatePreAllocatedTensorShape( "x", frameDims.width, "x3, got ", - shape); + intArrayRefToString(shape)); } } -int getDeviceIndex(const torch::Device& device) { +int getDeviceIndex(const StableDevice& device) { // PyTorch uses int8_t as its torch::DeviceIndex, but FFmpeg and CUDA // libraries use int. So we use int, too. int deviceIndex = static_cast(device.index()); - TORCH_CHECK( + STD_TORCH_CHECK( deviceIndex >= -1 && deviceIndex < MAX_CUDA_GPUS, "Invalid device index = ", deviceIndex); if (deviceIndex == -1) { - TORCH_CHECK( + STD_TORCH_CHECK( cudaGetDevice(&deviceIndex) == cudaSuccess, "Failed to get current CUDA device."); } diff --git a/src/torchcodec/_core/CUDACommon.h b/src/torchcodec/_core/CUDACommon.h index 588f60e49..9e50a4e25 100644 --- a/src/torchcodec/_core/CUDACommon.h +++ b/src/torchcodec/_core/CUDACommon.h @@ -6,13 +6,11 @@ #pragma once -#include -#include +#include #include -#include -#include "src/torchcodec/_core/FFMPEGCommon.h" -#include "src/torchcodec/_core/Frame.h" +#include "FFMPEGCommon.h" +#include "Frame.h" extern "C" { #include @@ -25,27 +23,30 @@ namespace facebook::torchcodec { // https://github.com/pytorch/pytorch/blob/e30c55ee527b40d67555464b9e402b4b7ce03737/c10/cuda/CUDAMacros.h#L44 constexpr int MAX_CUDA_GPUS = 128; -void initializeCudaContextWithPytorch(const torch::Device& device); +cudaStream_t getCurrentCudaStream(int32_t deviceIndex); + +void initializeCudaContextWithPytorch(const StableDevice& device); // Unique pointer type for NPP stream context using UniqueNppContext = std::unique_ptr; -torch::Tensor convertNV12FrameToRGB( +torch::stable::Tensor convertNV12FrameToRGB( 
UniqueAVFrame& avFrame, - const torch::Device& device, + const StableDevice& device, const UniqueNppContext& nppCtx, - at::cuda::CUDAStream nvdecStream, - std::optional preAllocatedOutputTensor = std::nullopt); + cudaStream_t nvdecStream, + std::optional preAllocatedOutputTensor = + std::nullopt); -UniqueNppContext getNppStreamContext(const torch::Device& device); +UniqueNppContext getNppStreamContext(const StableDevice& device); void returnNppStreamContextToCache( - const torch::Device& device, + const StableDevice& device, UniqueNppContext nppCtx); void validatePreAllocatedTensorShape( - const std::optional& preAllocatedOutputTensor, + const std::optional& preAllocatedOutputTensor, const UniqueAVFrame& avFrame); -int getDeviceIndex(const torch::Device& device); +int getDeviceIndex(const StableDevice& device); } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/Cache.h b/src/torchcodec/_core/Cache.h index b2c93e8ea..77bd6c51e 100644 --- a/src/torchcodec/_core/Cache.h +++ b/src/torchcodec/_core/Cache.h @@ -6,9 +6,9 @@ #pragma once -#include #include #include +#include "StableABICompat.h" namespace facebook::torchcodec { @@ -75,7 +75,7 @@ class PerGpuCache { // Initializes 'maxGpus' number of caches. Each cache can hold no // more than 'capacity' items. If 'capacity' <0 cache size is unlimited. PerGpuCache(int maxGpus, int capacity) { - TORCH_CHECK(maxGpus > 0, "maxGpus for PerGpuCache must be >0"); + STD_TORCH_CHECK(maxGpus > 0, "maxGpus for PerGpuCache must be >0"); for (int i = 0; i < maxGpus; ++i) { cache_.emplace_back(std::make_unique>(capacity)); } @@ -83,11 +83,11 @@ class PerGpuCache { // Adds an object to the specified device cache if the cache has // capacity. Returns true if object was added and false otherwise. - bool addIfCacheHasCapacity(const torch::Device& device, element_type&& obj); + bool addIfCacheHasCapacity(const StableDevice& device, element_type&& obj); // Returns an object from the cache of the specified device. 
Cache // does not hold a reference to the object after this call. - element_type get(const torch::Device& device); + element_type get(const StableDevice& device); private: // 'Cache' class implementation contains mutex which makes it non-movable @@ -98,14 +98,14 @@ class PerGpuCache { // Forward declaration of getDeviceIndex which exists in CUDACommon.h // This avoids circular dependency between Cache.h and CUDACommon.cpp which also // needs to include Cache.h -int getDeviceIndex(const torch::Device& device); +int getDeviceIndex(const StableDevice& device); template bool PerGpuCache::addIfCacheHasCapacity( - const torch::Device& device, + const StableDevice& device, element_type&& obj) { int deviceIndex = getDeviceIndex(device); - TORCH_CHECK( + STD_TORCH_CHECK( static_cast(deviceIndex) < cache_.size(), "Device index out of range"); return cache_[deviceIndex]->addIfCacheHasCapacity(std::move(obj)); @@ -113,9 +113,9 @@ bool PerGpuCache::addIfCacheHasCapacity( template typename PerGpuCache::element_type PerGpuCache::get( - const torch::Device& device) { + const StableDevice& device) { int deviceIndex = getDeviceIndex(device); - TORCH_CHECK( + STD_TORCH_CHECK( static_cast(deviceIndex) < cache_.size(), "Device index out of range"); return cache_[deviceIndex]->get(); diff --git a/src/torchcodec/_core/CpuDeviceInterface.cpp b/src/torchcodec/_core/CpuDeviceInterface.cpp index bb0988a13..d0925fc21 100644 --- a/src/torchcodec/_core/CpuDeviceInterface.cpp +++ b/src/torchcodec/_core/CpuDeviceInterface.cpp @@ -4,29 +4,29 @@ // This source code is licensed under the BSD-style license found in the // LICENSE file in the root directory of this source tree. 
-#include "src/torchcodec/_core/CpuDeviceInterface.h" +#include "CpuDeviceInterface.h" namespace facebook::torchcodec { namespace { static bool g_cpu = registerDeviceInterface( - DeviceInterfaceKey(torch::kCPU), - [](const torch::Device& device) { return new CpuDeviceInterface(device); }); + DeviceInterfaceKey(kStableCPU), + [](const StableDevice& device) { return new CpuDeviceInterface(device); }); } // namespace -CpuDeviceInterface::CpuDeviceInterface(const torch::Device& device) +CpuDeviceInterface::CpuDeviceInterface(const StableDevice& device) : DeviceInterface(device) { - TORCH_CHECK(g_cpu, "CpuDeviceInterface was not registered!"); - TORCH_CHECK( - device_.type() == torch::kCPU, "Unsupported device: ", device_.str()); + STD_TORCH_CHECK(g_cpu, "CpuDeviceInterface was not registered!"); + STD_TORCH_CHECK( + device_.type() == kStableCPU, "Unsupported device: must be CPU"); } void CpuDeviceInterface::initialize( const AVStream* avStream, [[maybe_unused]] const UniqueDecodingAVFormatContext& avFormatCtx, const SharedAVCodecContext& codecContext) { - TORCH_CHECK(avStream != nullptr, "avStream is null"); + STD_TORCH_CHECK(avStream != nullptr, "avStream is null"); codecContext_ = codecContext; timeBase_ = avStream->time_base; } @@ -35,16 +35,19 @@ void CpuDeviceInterface::initializeVideo( const VideoStreamOptions& videoStreamOptions, const std::vector>& transforms, const std::optional& resizedOutputDims) { + avMediaType_ = AVMEDIA_TYPE_VIDEO; videoStreamOptions_ = videoStreamOptions; resizedOutputDims_ = resizedOutputDims; - // We can only use swscale when we have a single resize transform. Note that - // this means swscale will not support the case of having several, - // back-to-base resizes. There's no strong reason to even do that, but if - // someone does, it's more correct to implement that with filtergraph. + // We can use swscale when we have a single resize transform. 
+ // With a single resize, we use swscale twice: + // first for color conversion (YUV->RGB24), then for resize in RGB24 space. + // + // Note that this means swscale will not support the case of having several, + // back-to-back resizes or other transforms. // - // We calculate this value during initilization but we don't refer to it until - // getColorConversionLibrary() is called. Calculating this value during + // We calculate this value during initialization but we don't refer to it + // until getColorConversionLibrary() is called. Calculating this value during // initialization saves us from having to save all of the transforms. areTransformsSwScaleCompatible_ = transforms.empty() || (transforms.size() == 1 && transforms[0]->isResize()); @@ -63,7 +66,8 @@ void CpuDeviceInterface::initializeVideo( // need to know the actual frame dimensions. if (transforms.size() == 1 && transforms[0]->isResize()) { auto resize = dynamic_cast(transforms[0].get()); - TORCH_CHECK(resize != nullptr, "ResizeTransform expected but not found!") + STD_TORCH_CHECK( + resize != nullptr, "ResizeTransform expected but not found!"); swsFlags_ = resize->getSwsFlags(); } @@ -80,17 +84,37 @@ void CpuDeviceInterface::initializeVideo( first = false; } if (!transforms.empty()) { - filters_ = filters.str(); + // Note [Transform and Format Conversion Order] + // We have to ensure that all user filters happen AFTER the explicit format + // conversion. That is, we want the filters to be applied in RGB24, not the + // pixel format of the input frame. + // + // The output frame will always be in RGB24, as we specify the sink node with + // AV_PIX_FMT_RGB24. Filtergraph will automatically insert a filter + // conversion to ensure the output frame matches the pixel format + // specified in the sink. But by default, it will insert it after the user + // filters. We need an explicit format conversion to get the behavior we + // want.
+ filters_ = "format=rgb24," + filters.str(); } initialized_ = true; } +void CpuDeviceInterface::initializeAudio( + const AudioStreamOptions& audioStreamOptions) { + avMediaType_ = AVMEDIA_TYPE_AUDIO; + audioStreamOptions_ = audioStreamOptions; + initialized_ = true; +} + ColorConversionLibrary CpuDeviceInterface::getColorConversionLibrary( + const FrameDims& inputDims, const FrameDims& outputDims) const { // swscale requires widths to be multiples of 32: // https://stackoverflow.com/questions/74351955/turn-off-sw-scale-conversion-to-planar-yuv-32-byte-alignment-requirements - bool isWidthSwScaleCompatible = (outputDims.width % 32) == 0; + bool areWidthsSwScaleCompatible = + (inputDims.width % 32) == 0 && (outputDims.width % 32) == 0; // We want to use swscale for color conversion if possible because it is // faster than filtergraph. The following are the conditions we need to meet @@ -107,28 +131,40 @@ ColorConversionLibrary CpuDeviceInterface::getColorConversionLibrary( // filtergraph in our public API, this is probably okay. It's also the only // way that we can be certain we are testing one versus the other. if (areTransformsSwScaleCompatible_ && - (userRequestedSwScale_ || isWidthSwScaleCompatible)) { + (userRequestedSwScale_ || areWidthsSwScaleCompatible)) { return ColorConversionLibrary::SWSCALE; } else { return ColorConversionLibrary::FILTERGRAPH; } } +void CpuDeviceInterface::convertAVFrameToFrameOutput( + UniqueAVFrame& avFrame, + FrameOutput& frameOutput, + std::optional preAllocatedOutputTensor) { + STD_TORCH_CHECK(initialized_, "CpuDeviceInterface was not initialized."); + + if (avMediaType_ == AVMEDIA_TYPE_AUDIO) { + convertAudioAVFrameToFrameOutput(avFrame, frameOutput); + } else { + convertVideoAVFrameToFrameOutput( + avFrame, frameOutput, preAllocatedOutputTensor); + } +} + // Note [preAllocatedOutputTensor with swscale and filtergraph]: // Callers may pass a pre-allocated tensor, where the output.data tensor will // be stored. 
This parameter is honored in any case, but it only leads to a // speed-up when swscale is used. With swscale, we can tell ffmpeg to place the -// decoded frame directly into `preAllocatedtensor.data_ptr()`. We haven't yet -// found a way to do that with filtegraph. +// decoded frame directly into `preAllocatedtensor.mutable_data_ptr()`. We +// haven't yet found a way to do that with filtergraph. // TODO: Figure out whether that's possible! // Dimension order of the preAllocatedOutputTensor must be HWC, regardless of // `dimension_order` parameter. It's up to callers to re-shape it if needed. -void CpuDeviceInterface::convertAVFrameToFrameOutput( +void CpuDeviceInterface::convertVideoAVFrameToFrameOutput( UniqueAVFrame& avFrame, FrameOutput& frameOutput, - std::optional preAllocatedOutputTensor) { - TORCH_CHECK(initialized_, "CpuDeviceInterface was not initialized."); - + std::optional preAllocatedOutputTensor) { // Note that we ignore the dimensions from the metadata; we don't even bother // storing them. The resized dimensions take priority. If we don't have any, // then we use the dimensions from the actual decoded frame. We use the actual
- auto outputDims = - resizedOutputDims_.value_or(FrameDims(avFrame->height, avFrame->width)); + auto inputDims = FrameDims(avFrame->height, avFrame->width); + auto outputDims = resizedOutputDims_.value_or(inputDims); if (preAllocatedOutputTensor.has_value()) { auto shape = preAllocatedOutputTensor.value().sizes(); - TORCH_CHECK( + STD_TORCH_CHECK( (shape.size() == 3) && (shape[0] == outputDims.height) && (shape[1] == outputDims.width) && (shape[2] == 3), "Expected pre-allocated tensor of shape ", @@ -155,23 +191,38 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput( "x", outputDims.width, "x3, got ", - shape); + intArrayRefToString(shape)); } - auto colorConversionLibrary = getColorConversionLibrary(outputDims); - torch::Tensor outputTensor; + auto colorConversionLibrary = + getColorConversionLibrary(inputDims, outputDims); + torch::stable::Tensor outputTensor; if (colorConversionLibrary == ColorConversionLibrary::SWSCALE) { outputTensor = preAllocatedOutputTensor.value_or( - allocateEmptyHWCTensor(outputDims, torch::kCPU)); + allocateEmptyHWCTensor(outputDims, kStableCPU)); - int resultHeight = - convertAVFrameToTensorUsingSwScale(avFrame, outputTensor, outputDims); + enum AVPixelFormat avFrameFormat = + static_cast(avFrame->format); + + SwsConfig swsConfig( + avFrame->width, + avFrame->height, + avFrameFormat, + avFrame->colorspace, + outputDims.width, + outputDims.height); + + if (!swScale_ || swScale_->getConfig() != swsConfig) { + swScale_ = std::make_unique(swsConfig, swsFlags_); + } + + int resultHeight = swScale_->convert(avFrame, outputTensor); // If this check failed, it would mean that the frame wasn't reshaped to // the expected height. // TODO: Can we do the same check for width? 
- TORCH_CHECK( + STD_TORCH_CHECK( resultHeight == outputDims.height, "resultHeight != outputDims.height: ", resultHeight, @@ -185,7 +236,7 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput( // Similarly to above, if this check fails it means the frame wasn't // reshaped to its expected dimensions by filtergraph. auto shape = outputTensor.sizes(); - TORCH_CHECK( + STD_TORCH_CHECK( (shape.size() == 3) && (shape[0] == outputDims.height) && (shape[1] == outputDims.width) && (shape[2] == 3), "Expected output tensor of shape ", @@ -193,93 +244,249 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput( "x", outputDims.width, "x3, got ", - shape); + intArrayRefToString(shape)); if (preAllocatedOutputTensor.has_value()) { // We have already validated that preAllocatedOutputTensor and // outputTensor have the same shape. - preAllocatedOutputTensor.value().copy_(outputTensor); + torch::stable::copy_(preAllocatedOutputTensor.value(), outputTensor); frameOutput.data = preAllocatedOutputTensor.value(); } else { frameOutput.data = outputTensor; } } else { - TORCH_CHECK( + STD_TORCH_CHECK( false, "Invalid color conversion library: ", static_cast(colorConversionLibrary)); } } -int CpuDeviceInterface::convertAVFrameToTensorUsingSwScale( +torch::stable::Tensor +CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph( const UniqueAVFrame& avFrame, - torch::Tensor& outputTensor, const FrameDims& outputDims) { - enum AVPixelFormat frameFormat = + enum AVPixelFormat avFrameFormat = static_cast(avFrame->format); - // We need to compare the current frame context with our previous frame - // context. If they are different, then we need to re-create our colorspace - // conversion objects. We create our colorspace conversion objects late so - // that we don't have to depend on the unreliable metadata in the header. - // And we sometimes re-create them because it's possible for frame - // resolution to change mid-stream. 
Finally, we want to reuse the colorspace - // conversion objects as much as possible for performance reasons. - SwsFrameContext swsFrameContext( + FiltersConfig filtersConfig( avFrame->width, avFrame->height, - frameFormat, - outputDims.width, - outputDims.height); - - if (!swsContext_ || prevSwsFrameContext_ != swsFrameContext) { - swsContext_ = createSwsContext( - swsFrameContext, avFrame->colorspace, AV_PIX_FMT_RGB24, swsFlags_); - prevSwsFrameContext_ = swsFrameContext; - } - - uint8_t* pointers[4] = { - outputTensor.data_ptr(), nullptr, nullptr, nullptr}; - int expectedOutputWidth = outputTensor.sizes()[1]; - int linesizes[4] = {expectedOutputWidth * 3, 0, 0, 0}; - int resultHeight = sws_scale( - swsContext_.get(), - avFrame->data, - avFrame->linesize, - 0, - avFrame->height, - pointers, - linesizes); - return resultHeight; -} - -torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph( - const UniqueAVFrame& avFrame, - const FrameDims& outputDims) { - enum AVPixelFormat frameFormat = - static_cast(avFrame->format); - - FiltersContext filtersContext( - avFrame->width, - avFrame->height, - frameFormat, + avFrameFormat, avFrame->sample_aspect_ratio, outputDims.width, outputDims.height, - AV_PIX_FMT_RGB24, + /*outputFormat=*/AV_PIX_FMT_RGB24, filters_, timeBase_); - if (!filterGraph_ || prevFiltersContext_ != filtersContext) { + if (!filterGraph_ || prevFiltersConfig_ != filtersConfig) { filterGraph_ = - std::make_unique(filtersContext, videoStreamOptions_); - prevFiltersContext_ = std::move(filtersContext); + std::make_unique(filtersConfig, videoStreamOptions_); + prevFiltersConfig_ = std::move(filtersConfig); } return rgbAVFrameToTensor(filterGraph_->convert(avFrame)); } +void CpuDeviceInterface::convertAudioAVFrameToFrameOutput( + UniqueAVFrame& srcAVFrame, + FrameOutput& frameOutput) { + AVSampleFormat srcSampleFormat = + static_cast(srcAVFrame->format); + AVSampleFormat outSampleFormat = AV_SAMPLE_FMT_FLTP; + + int srcSampleRate = 
srcAVFrame->sample_rate; + int outSampleRate = audioStreamOptions_.sampleRate.value_or(srcSampleRate); + + int srcNumChannels = getNumChannels(codecContext_); + STD_TORCH_CHECK( + srcNumChannels == getNumChannels(srcAVFrame), + "The frame has ", + getNumChannels(srcAVFrame), + " channels, expected ", + srcNumChannels, + ". If you are hitting this, it may be because you are using " + "a buggy FFmpeg version. FFmpeg4 is known to fail here in some " + "valid scenarios. Try to upgrade FFmpeg?"); + int outNumChannels = audioStreamOptions_.numChannels.value_or(srcNumChannels); + + bool mustConvert = + (srcSampleFormat != outSampleFormat || srcSampleRate != outSampleRate || + srcNumChannels != outNumChannels); + + UniqueAVFrame convertedAVFrame; + if (mustConvert) { + if (!swrContext_) { + swrContext_.reset(createSwrContext( + srcSampleFormat, + outSampleFormat, + srcSampleRate, + outSampleRate, + srcAVFrame, + outNumChannels)); + } + + convertedAVFrame = convertAudioAVFrameSamples( + swrContext_, + srcAVFrame, + outSampleFormat, + outSampleRate, + outNumChannels); + } + const UniqueAVFrame& avFrame = mustConvert ? convertedAVFrame : srcAVFrame; + + AVSampleFormat format = static_cast(avFrame->format); + STD_TORCH_CHECK( + format == outSampleFormat, + "Something went wrong, the frame didn't get converted to the desired format. ", + "Desired format = ", + av_get_sample_fmt_name(outSampleFormat), + "source format = ", + av_get_sample_fmt_name(format)); + + int numChannels = getNumChannels(avFrame); + STD_TORCH_CHECK( + numChannels == outNumChannels, + "Something went wrong, the frame didn't get converted to the desired ", + "number of channels = ", + outNumChannels, + ". 
Got ", + numChannels, + " instead."); + + auto numSamples = avFrame->nb_samples; + + frameOutput.data = torch::stable::empty({numChannels, numSamples}); + + if (numSamples > 0) { + uint8_t* outputChannelData = + static_cast(frameOutput.data.mutable_data_ptr()); + auto numBytesPerChannel = numSamples * av_get_bytes_per_sample(format); + for (auto channel = 0; channel < numChannels; + ++channel, outputChannelData += numBytesPerChannel) { + std::memcpy( + outputChannelData, + avFrame->extended_data[channel], + numBytesPerChannel); + } + } +} + +std::optional +CpuDeviceInterface::maybeFlushAudioBuffers() { + // When sample rate conversion is involved, swresample buffers some of the + // samples in-between calls to swr_convert (see the libswresample docs). + // That's because the last few samples in a given frame require future + // samples from the next frame to be properly converted. This function + // flushes out the samples that are stored in swresample's buffers. + if (!swrContext_) { + return std::nullopt; + } + auto numRemainingSamples = // this is an upper bound + swr_get_out_samples(swrContext_.get(), 0); + + if (numRemainingSamples == 0) { + return std::nullopt; + } + + int numChannels = + audioStreamOptions_.numChannels.value_or(getNumChannels(codecContext_)); + torch::stable::Tensor lastSamples = + torch::stable::empty({numChannels, numRemainingSamples}); + + std::vector outputBuffers(numChannels); + for (auto i = 0; i < numChannels; i++) { + outputBuffers[i] = reinterpret_cast( + selectRow(lastSamples, i).mutable_data_ptr()); + } + + auto actualNumRemainingSamples = swr_convert( + swrContext_.get(), outputBuffers.data(), numRemainingSamples, nullptr, 0); + + return torch::stable::narrow( + lastSamples, + /*dim=*/1, + /*start=*/0, + /*length=*/actualNumRemainingSamples); +} + std::string CpuDeviceInterface::getDetails() { return std::string("CPU Device Interface."); } +UniqueAVFrame CpuDeviceInterface::convertTensorToAVFrameForEncoding( + const 
torch::stable::Tensor& frame, + int frameIndex, + AVCodecContext* codecContext) { + int inHeight = static_cast(frame.sizes()[1]); + int inWidth = static_cast(frame.sizes()[2]); + AVPixelFormat inPixelFormat = AV_PIX_FMT_GBRP; + int outWidth = codecContext->width; + int outHeight = codecContext->height; + AVPixelFormat outPixelFormat = codecContext->pix_fmt; + + // Initialize and cache scaling context if it does not exist + if (!encodingSwsContext_) { + encodingSwsContext_.reset(sws_getContext( + inWidth, + inHeight, + inPixelFormat, + outWidth, + outHeight, + outPixelFormat, + SWS_BICUBIC, // Used by FFmpeg CLI + nullptr, + nullptr, + nullptr)); + STD_TORCH_CHECK( + encodingSwsContext_ != nullptr, "Failed to create scaling context"); + } + + UniqueAVFrame avFrame(av_frame_alloc()); + STD_TORCH_CHECK(avFrame != nullptr, "Failed to allocate AVFrame"); + + // Set output frame properties + avFrame->format = outPixelFormat; + avFrame->width = outWidth; + avFrame->height = outHeight; + avFrame->pts = frameIndex; + + int status = av_frame_get_buffer(avFrame.get(), 0); + STD_TORCH_CHECK(status >= 0, "Failed to allocate frame buffer"); + + // Need to convert/scale the frame + // Create temporary frame with input format + UniqueAVFrame inputFrame(av_frame_alloc()); + STD_TORCH_CHECK(inputFrame != nullptr, "Failed to allocate input AVFrame"); + + inputFrame->format = inPixelFormat; + inputFrame->width = inWidth; + inputFrame->height = inHeight; + + uint8_t* tensorData = static_cast(frame.mutable_data_ptr()); + + int channelSize = inHeight * inWidth; + // Since frames tensor is in NCHW, we must use a planar format. + // FFmpeg only provides AV_PIX_FMT_GBRP for planar RGB, + // so we reorder RGB -> GBR. 
+ inputFrame->data[0] = tensorData + channelSize; + inputFrame->data[1] = tensorData + (2 * channelSize); + inputFrame->data[2] = tensorData; + + inputFrame->linesize[0] = inWidth; + inputFrame->linesize[1] = inWidth; + inputFrame->linesize[2] = inWidth; + + status = sws_scale( + encodingSwsContext_.get(), + inputFrame->data, + inputFrame->linesize, + 0, + inputFrame->height, + avFrame->data, + avFrame->linesize); + STD_TORCH_CHECK(status == outHeight, "sws_scale failed"); + return avFrame; +} + } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/CpuDeviceInterface.h b/src/torchcodec/_core/CpuDeviceInterface.h index f7c57045a..7cec64a43 100644 --- a/src/torchcodec/_core/CpuDeviceInterface.h +++ b/src/torchcodec/_core/CpuDeviceInterface.h @@ -6,20 +6,22 @@ #pragma once -#include "src/torchcodec/_core/DeviceInterface.h" -#include "src/torchcodec/_core/FFMPEGCommon.h" -#include "src/torchcodec/_core/FilterGraph.h" +#include "DeviceInterface.h" +#include "FFMPEGCommon.h" +#include "FilterGraph.h" +#include "SwScale.h" namespace facebook::torchcodec { class CpuDeviceInterface : public DeviceInterface { public: - CpuDeviceInterface(const torch::Device& device); + CpuDeviceInterface(const StableDevice& device); virtual ~CpuDeviceInterface() {} std::optional findCodec( - [[maybe_unused]] const AVCodecID& codecId) override { + [[maybe_unused]] const AVCodecID& codecId, + [[maybe_unused]] bool isDecoder = true) override { return std::nullopt; } @@ -33,26 +35,41 @@ class CpuDeviceInterface : public DeviceInterface { const std::vector>& transforms, const std::optional& resizedOutputDims) override; + virtual void initializeAudio( + const AudioStreamOptions& audioStreamOptions) override; + + virtual std::optional maybeFlushAudioBuffers() + override; + void convertAVFrameToFrameOutput( UniqueAVFrame& avFrame, FrameOutput& frameOutput, - std::optional preAllocatedOutputTensor = - std::nullopt) override; + std::optional preAllocatedOutputTensor) override; + + 
UniqueAVFrame convertTensorToAVFrameForEncoding( + const torch::stable::Tensor& tensor, + int frameIndex, + AVCodecContext* codecContext) override; std::string getDetails() override; private: - int convertAVFrameToTensorUsingSwScale( - const UniqueAVFrame& avFrame, - torch::Tensor& outputTensor, - const FrameDims& outputDims); + void convertAudioAVFrameToFrameOutput( + UniqueAVFrame& srcAVFrame, + FrameOutput& frameOutput); - torch::Tensor convertAVFrameToTensorUsingFilterGraph( + void convertVideoAVFrameToFrameOutput( + UniqueAVFrame& avFrame, + FrameOutput& frameOutput, + std::optional preAllocatedOutputTensor); + + torch::stable::Tensor convertAVFrameToTensorUsingFilterGraph( const UniqueAVFrame& avFrame, const FrameDims& outputDims); ColorConversionLibrary getColorConversionLibrary( - const FrameDims& inputFrameDims) const; + const FrameDims& inputDims, + const FrameDims& outputDims) const; VideoStreamOptions videoStreamOptions_; AVRational timeBase_; @@ -66,48 +83,54 @@ class CpuDeviceInterface : public DeviceInterface { // resolutions. std::optional resizedOutputDims_; - // Color-conversion objects. Only one of filterGraph_ and swsContext_ should + // Color-conversion objects. Only one of filterGraph_ and swScale_ should // be non-null. Which one we use is determined dynamically in // getColorConversionLibrary() each time we decode a frame. // - // Creating both filterGraph_ and swsContext_ is relatively expensive, so we - // reuse them across frames. However, it is possbile that subsequent frames + // Creating both filterGraph_ and swScale_ is relatively expensive, so we + // reuse them across frames. However, it is possible that subsequent frames // are different enough (change in dimensions) that we can't reuse the color - // conversion object. We store the relevant frame context from the frame used + // conversion object. We store the relevant frame config from the frame used // to create the object last time. 
We always compare the current frame's info // against the previous one to determine if we need to recreate the color // conversion object. - // - // TODO: The names of these fields is confusing, as the actual color - // conversion object for Sws has "context" in the name, and we use - // "context" for the structs we store to know if we need to recreate a - // color conversion object. We should clean that up. std::unique_ptr filterGraph_; - FiltersContext prevFiltersContext_; - UniqueSwsContext swsContext_; - SwsFrameContext prevSwsFrameContext_; - - // The filter we supply to filterGraph_, if it is used. The default is the - // copy filter, which just copies the input to the output. Computationally, it - // should be a no-op. If we get no user-provided transforms, we will use the - // copy filter. Otherwise, we will construct the string from the transforms. + FiltersConfig prevFiltersConfig_; + std::unique_ptr swScale_; + + // Cached swscale context for encoding (tensor -> AVFrame pixel format + // conversion). + UniqueSwsContext encodingSwsContext_; + + // We pass these filters to FFmpeg's filtergraph API. It is a simple pipeline + // of what FFmpeg calls "filters" to apply to decoded frames before returning + // them. In the PyTorch ecosystem, we call these "transforms". During + // initialization, we convert the user-supplied transforms into this string of + // filters. // - // Note that even if we only use the copy filter, we still get the desired - // colorspace conversion. We construct the filtergraph with its output sink - // set to RGB24. + // Note that if there are no user-supplied transforms, then the default filter + // we use is the copy filter, which is just an identity: it emits the output + // frame unchanged. We supply such a filter because we can't supply just the + // empty-string; we must supply SOME filter. + // + // See also [Transform and Format Conversion Order] for more on filters.
std::string filters_ = "copy"; - // The flags we supply to swsContext_, if it used. The flags control the - // resizing algorithm. We default to bilinear. Users can override this with a - // ResizeTransform. - int swsFlags_ = SWS_BILINEAR; - // Values set during initialization and referred to in // getColorConversionLibrary(). bool areTransformsSwScaleCompatible_; bool userRequestedSwScale_; + // The flags we supply to the resize swscale context. The flags control the + // resizing algorithm. We default to bilinear. Users can override this with a + // ResizeTransform that specifies a different interpolation mode. + int swsFlags_ = SWS_BILINEAR; + bool initialized_ = false; + + // Audio-specific members + AudioStreamOptions audioStreamOptions_; + UniqueSwrContext swrContext_; }; } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/CudaDeviceInterface.cpp b/src/torchcodec/_core/CudaDeviceInterface.cpp index be45050e6..ad664f206 100644 --- a/src/torchcodec/_core/CudaDeviceInterface.cpp +++ b/src/torchcodec/_core/CudaDeviceInterface.cpp @@ -1,11 +1,11 @@ -#include -#include -#include +#include #include -#include "src/torchcodec/_core/Cache.h" -#include "src/torchcodec/_core/CudaDeviceInterface.h" -#include "src/torchcodec/_core/FFMPEGCommon.h" +#include "Cache.h" +#include "CudaDeviceInterface.h" +#include "FFMPEGCommon.h" +#include "StableABICompat.h" +#include "ValidationUtils.h" extern "C" { #include @@ -16,10 +16,8 @@ namespace facebook::torchcodec { namespace { static bool g_cuda = registerDeviceInterface( - DeviceInterfaceKey(torch::kCUDA), - [](const torch::Device& device) { - return new CudaDeviceInterface(device); - }); + DeviceInterfaceKey(kStableCUDA), + [](const StableDevice& device) { return new CudaDeviceInterface(device); }); // We reuse cuda contexts across VideoDeoder instances. This is because // creating a cuda context is expensive. 
The cache mechanism is as follows: @@ -48,9 +46,9 @@ int getFlagsAVHardwareDeviceContextCreate() { #endif } -UniqueAVBufferRef getHardwareDeviceContext(const torch::Device& device) { +UniqueAVBufferRef getHardwareDeviceContext(const StableDevice& device) { enum AVHWDeviceType type = av_hwdevice_find_type_by_name("cuda"); - TORCH_CHECK(type != AV_HWDEVICE_TYPE_NONE, "Failed to find cuda device"); + STD_TORCH_CHECK(type != AV_HWDEVICE_TYPE_NONE, "Failed to find cuda device"); int deviceIndex = getDeviceIndex(device); UniqueAVBufferRef hardwareDeviceCtx = g_cached_hw_device_ctxs.get(device); @@ -59,10 +57,10 @@ UniqueAVBufferRef getHardwareDeviceContext(const torch::Device& device) { } // Create hardware device context - c10::cuda::CUDAGuard deviceGuard(device); + StableDeviceGuard deviceGuard(device.index()); // We set the device because we may be called from a different thread than // the one that initialized the cuda context. - TORCH_CHECK( + STD_TORCH_CHECK( cudaSetDevice(deviceIndex) == cudaSuccess, "Failed to set CUDA device"); AVBufferRef* hardwareDeviceCtxRaw = nullptr; std::string deviceOrdinal = std::to_string(deviceIndex); @@ -76,7 +74,7 @@ UniqueAVBufferRef getHardwareDeviceContext(const torch::Device& device) { if (err < 0) { /* clang-format off */ - TORCH_CHECK( + STD_TORCH_CHECK( false, "Failed to create specified HW device. 
This typically happens when ", "your installed FFmpeg doesn't support CUDA (see ", @@ -90,11 +88,11 @@ UniqueAVBufferRef getHardwareDeviceContext(const torch::Device& device) { } // namespace -CudaDeviceInterface::CudaDeviceInterface(const torch::Device& device) +CudaDeviceInterface::CudaDeviceInterface(const StableDevice& device) : DeviceInterface(device) { - TORCH_CHECK(g_cuda, "CudaDeviceInterface was not registered!"); - TORCH_CHECK( - device_.type() == torch::kCUDA, "Unsupported device: ", device_.str()); + STD_TORCH_CHECK(g_cuda, "CudaDeviceInterface was not registered!"); + STD_TORCH_CHECK( + device_.type() == kStableCUDA, "Unsupported device: must be CUDA"); initializeCudaContextWithPytorch(device_); @@ -114,13 +112,13 @@ void CudaDeviceInterface::initialize( const AVStream* avStream, const UniqueDecodingAVFormatContext& avFormatCtx, const SharedAVCodecContext& codecContext) { - TORCH_CHECK(avStream != nullptr, "avStream is null"); + STD_TORCH_CHECK(avStream != nullptr, "avStream is null"); codecContext_ = codecContext; timeBase_ = avStream->time_base; // TODO: Ideally, we should keep all interface implementations independent. 
- cpuInterface_ = createDeviceInterface(torch::kCPU); - TORCH_CHECK( + cpuInterface_ = createDeviceInterface(kStableCPU); + STD_TORCH_CHECK( cpuInterface_ != nullptr, "Failed to create CPU device interface"); cpuInterface_->initialize(avStream, avFormatCtx, codecContext); cpuInterface_->initializeVideo( @@ -138,9 +136,9 @@ void CudaDeviceInterface::initializeVideo( void CudaDeviceInterface::registerHardwareDeviceWithCodec( AVCodecContext* codecContext) { - TORCH_CHECK( + STD_TORCH_CHECK( hardwareDeviceCtx_, "Hardware device context has not been initialized"); - TORCH_CHECK(codecContext != nullptr, "codecContext is null"); + STD_TORCH_CHECK(codecContext != nullptr, "codecContext is null"); codecContext->hw_device_ctx = av_buffer_ref(hardwareDeviceCtx_.get()); } @@ -159,7 +157,7 @@ UniqueAVFrame CudaDeviceInterface::maybeConvertAVFrameToNV12OrRGB24( auto hwFramesCtx = reinterpret_cast(avFrame->hw_frames_ctx->data); - TORCH_CHECK( + STD_TORCH_CHECK( hwFramesCtx != nullptr, "The AVFrame does not have a hw_frames_ctx. " "That's unexpected, please report this to the TorchCodec repo."); @@ -183,7 +181,7 @@ UniqueAVFrame CudaDeviceInterface::maybeConvertAVFrameToNV12OrRGB24( outputFormat = AV_PIX_FMT_RGB24; auto actualFormatName = av_get_pix_fmt_name(actualFormat); - TORCH_CHECK( + STD_TORCH_CHECK( actualFormatName != nullptr, "The actual format of a frame is unknown to FFmpeg. 
" "That's unexpected, please report this to the TorchCodec repo."); @@ -199,7 +197,7 @@ UniqueAVFrame CudaDeviceInterface::maybeConvertAVFrameToNV12OrRGB24( enum AVPixelFormat frameFormat = static_cast(avFrame->format); - auto newContext = std::make_unique( + auto newConfig = std::make_unique( avFrame->width, avFrame->height, frameFormat, @@ -211,22 +209,22 @@ UniqueAVFrame CudaDeviceInterface::maybeConvertAVFrameToNV12OrRGB24( timeBase_, av_buffer_ref(avFrame->hw_frames_ctx)); - if (!nv12Conversion_ || *nv12ConversionContext_ != *newContext) { + if (!nv12Conversion_ || *nv12ConversionConfig_ != *newConfig) { nv12Conversion_ = - std::make_unique(*newContext, videoStreamOptions_); - nv12ConversionContext_ = std::move(newContext); + std::make_unique(*newConfig, videoStreamOptions_); + nv12ConversionConfig_ = std::move(newConfig); } auto filteredAVFrame = nv12Conversion_->convert(avFrame); // If this check fails it means the frame wasn't // reshaped to its expected dimensions by filtergraph. - TORCH_CHECK( - (filteredAVFrame->width == nv12ConversionContext_->outputWidth) && - (filteredAVFrame->height == nv12ConversionContext_->outputHeight), + STD_TORCH_CHECK( + (filteredAVFrame->width == nv12ConversionConfig_->outputWidth) && + (filteredAVFrame->height == nv12ConversionConfig_->outputHeight), "Expected frame from filter graph of ", - nv12ConversionContext_->outputWidth, + nv12ConversionConfig_->outputWidth, "x", - nv12ConversionContext_->outputHeight, + nv12ConversionConfig_->outputHeight, ", got ", filteredAVFrame->width, "x", @@ -238,9 +236,11 @@ UniqueAVFrame CudaDeviceInterface::maybeConvertAVFrameToNV12OrRGB24( void CudaDeviceInterface::convertAVFrameToFrameOutput( UniqueAVFrame& avFrame, FrameOutput& frameOutput, - std::optional preAllocatedOutputTensor) { + std::optional preAllocatedOutputTensor) { validatePreAllocatedTensorShape(preAllocatedOutputTensor, avFrame); + hasDecodedFrame_ = true; + // All of our CUDA decoding assumes NV12 format. 
We handle non-NV12 formats by // converting them to NV12. avFrame = maybeConvertAVFrameToNV12OrRGB24(avFrame); @@ -278,10 +278,11 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput( // pre-allocated tensor is on the GPU, so we can't send that to the CPU // device interface. We copy it over here. if (preAllocatedOutputTensor.has_value()) { - preAllocatedOutputTensor.value().copy_(cpuFrameOutput.data); + torch::stable::copy_( + preAllocatedOutputTensor.value(), cpuFrameOutput.data); frameOutput.data = preAllocatedOutputTensor.value(); } else { - frameOutput.data = cpuFrameOutput.data.to(device_); + frameOutput.data = torch::stable::to(cpuFrameOutput.data, device_); } usingCPUFallback_ = true; @@ -294,17 +295,17 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput( // also need to check that the AVFrame is in AV_PIX_FMT_NV12 format (8 bits), // because this is what the NPP color conversion routines expect. This SHOULD // be enforced by our call to maybeConvertAVFrameToNV12OrRGB24() above. - TORCH_CHECK( + STD_TORCH_CHECK( avFrame->hw_frames_ctx != nullptr, "The AVFrame does not have a hw_frames_ctx. This should never happen"); AVHWFramesContext* hwFramesCtx = reinterpret_cast(avFrame->hw_frames_ctx->data); - TORCH_CHECK( + STD_TORCH_CHECK( hwFramesCtx != nullptr, "The AVFrame does not have a valid hw_frames_ctx. This should never happen"); AVPixelFormat actualFormat = hwFramesCtx->sw_format; - TORCH_CHECK( + STD_TORCH_CHECK( actualFormat == AV_PIX_FMT_NV12, "The AVFrame is ", (av_get_pix_fmt_name(actualFormat) ? av_get_pix_fmt_name(actualFormat) @@ -316,14 +317,15 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput( // In reality, we know that this stream is hardcoded to be the default stream // by FFmpeg: // https://github.com/FFmpeg/FFmpeg/blob/66e40840d15b514f275ce3ce2a4bf72ec68c7311/libavutil/hwcontext_cuda.c#L387-L388 - TORCH_CHECK( + STD_TORCH_CHECK( hwFramesCtx->device_ctx != nullptr, "The AVFrame's hw_frames_ctx does not have a device_ctx. 
"); auto cudaDeviceCtx = static_cast(hwFramesCtx->device_ctx->hwctx); - TORCH_CHECK(cudaDeviceCtx != nullptr, "The hardware context is null"); - at::cuda::CUDAStream nvdecStream = // That's always the default stream. Sad. - c10::cuda::getStreamFromExternal(cudaDeviceCtx->stream, device_.index()); + STD_TORCH_CHECK(cudaDeviceCtx != nullptr, "The hardware context is null"); + + cudaStream_t nvdecStream = // That's always the default stream. Sad. + cudaDeviceCtx->stream; frameOutput.data = convertNV12FrameToRGB( avFrame, device_, nppCtx_, nvdecStream, preAllocatedOutputTensor); @@ -334,12 +336,22 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput( // appropriately set, so we just go off and find the matching codec for the CUDA // device std::optional CudaDeviceInterface::findCodec( - const AVCodecID& codecId) { + const AVCodecID& codecId, + bool isDecoder) { void* i = nullptr; const AVCodec* codec = nullptr; while ((codec = av_codec_iterate(&i)) != nullptr) { - if (codec->id != codecId || !av_codec_is_decoder(codec)) { - continue; + STD_TORCH_CHECK( + codec != nullptr, + "codec returned by av_codec_iterate should not be null"); + if (isDecoder) { + if (codec->id != codecId || !av_codec_is_decoder(codec)) { + continue; + } + } else { + if (codec->id != codecId || !av_codec_is_encoder(codec)) { + continue; + } } const AVCodecHWConfig* config = nullptr; @@ -358,8 +370,236 @@ std::string CudaDeviceInterface::getDetails() { // Note: for this interface specifically the fallback is only known after a // frame has been decoded, not before: that's when FFmpeg decides to fallback, // so we can't know earlier. + if (!hasDecodedFrame_) { + return std::string( + "FFmpeg CUDA Device Interface. Fallback status unknown (no frames decoded)."); + } return std::string("FFmpeg CUDA Device Interface. Using ") + (usingCPUFallback_ ? "CPU fallback." 
: "NVDEC."); } +// -------------------------------------------------------------------------- +// Below are methods exclusive to video encoding: +// -------------------------------------------------------------------------- +namespace { +// Note: [RGB -> YUV Color Conversion, limited color range] +// +// For context on this subject, first read the note: +// [YUV -> RGB Color Conversion, color space and color range] +// https://github.com/meta-pytorch/torchcodec/blob/main/src/torchcodec/_core/CUDACommon.cpp#L63-L65 +// +// Lets encode RGB -> YUV in the limited color range for BT.601 color space. +// In limited range, the [0, 255] range is mapped into [16-235] for Y, and into +// [16-240] for U,V. +// To implement, we get the full range conversion matrix as before, then scale: +// - Y channel: scale by (235-16)/255 = 219/255 +// - U,V channels: scale by (240-16)/255 = 224/255 +// https://en.wikipedia.org/wiki/YCbCr#Y%E2%80%B2PbPr_to_Y%E2%80%B2CbCr +// +// ```py +// import torch +// kr, kg, kb = 0.299, 0.587, 0.114 # BT.601 luma coefficients +// u_scale = 2 * (1 - kb) +// v_scale = 2 * (1 - kr) +// +// rgb_to_yuv_full = torch.tensor([ +// [kr, kg, kb], +// [-kr/u_scale, -kg/u_scale, (1-kb)/u_scale], +// [(1-kr)/v_scale, -kg/v_scale, -kb/v_scale] +// ]) +// +// full_to_limited_y_scale = 219.0 / 255.0 +// full_to_limited_uv_scale = 224.0 / 255.0 +// +// rgb_to_yuv_limited = rgb_to_yuv_full * torch.tensor([ +// [full_to_limited_y_scale], +// [full_to_limited_uv_scale], +// [full_to_limited_uv_scale] +// ]) +// +// print("RGB->YUV matrix (Limited Range BT.601):") +// print(rgb_to_yuv_limited) +// ``` +// +// This yields: +// tensor([[ 0.2568, 0.5041, 0.0979], +// [-0.1482, -0.2910, 0.4392], +// [ 0.4392, -0.3678, -0.0714]]) +// +// Which matches https://fourcc.org/fccyvrgb.php +// +// To perform color conversion in NPP, we are required to provide these color +// conversion matrices to ColorTwist functions, for example, +// `nppiRGBToNV12_8u_ColorTwist32f_C3P2R_Ctx`. 
+// https://docs.nvidia.com/cuda/npp/image_color_conversion.html +// +// These offsets are added in the 4th column of each conversion matrix below. +// - In limited range, Y is offset by 16 to add the lower margin. +// - In both color ranges, U,V are offset by 128 to be centered around 0. +// +// RGB to YUV conversion matrices to use in NPP color conversion functions +struct ColorConversionMatrices { + static constexpr Npp32f BT601_LIMITED[3][4] = { + {0.2568f, 0.5041f, 0.0979f, 16.0f}, + {-0.1482f, -0.2910f, 0.4392f, 128.0f}, + {0.4392f, -0.3678f, -0.0714f, 128.0f}}; + + static constexpr Npp32f BT601_FULL[3][4] = { + {0.2990f, 0.5870f, 0.1140f, 0.0f}, + {-0.1687f, -0.3313f, 0.5000f, 128.0f}, + {0.5000f, -0.4187f, -0.0813f, 128.0f}}; + + static constexpr Npp32f BT709_LIMITED[3][4] = { + {0.1826f, 0.6142f, 0.0620f, 16.0f}, + {-0.1006f, -0.3386f, 0.4392f, 128.0f}, + {0.4392f, -0.3989f, -0.0403f, 128.0f}}; + + static constexpr Npp32f BT709_FULL[3][4] = { + {0.2126f, 0.7152f, 0.0722f, 0.0f}, + {-0.1146f, -0.3854f, 0.5000f, 128.0f}, + {0.5000f, -0.4542f, -0.0458f, 128.0f}}; + + static constexpr Npp32f BT2020_LIMITED[3][4] = { + {0.2256f, 0.5823f, 0.0509f, 16.0f}, + {-0.1227f, -0.3166f, 0.4392f, 128.0f}, + {0.4392f, -0.4039f, -0.0353f, 128.0f}}; + + static constexpr Npp32f BT2020_FULL[3][4] = { + {0.2627f, 0.6780f, 0.0593f, 0.0f}, + {-0.139630f, -0.360370f, 0.5000f, 128.0f}, + {0.5000f, -0.459786f, -0.040214f, 128.0f}}; +}; + +// Returns conversion matrix based on codec context color space and range +const Npp32f (*getConversionMatrix(AVCodecContext* codecContext))[4] { + if (codecContext->color_range == AVCOL_RANGE_MPEG || // limited range + codecContext->color_range == AVCOL_RANGE_UNSPECIFIED) { + if (codecContext->colorspace == AVCOL_SPC_BT470BG) { + return ColorConversionMatrices::BT601_LIMITED; + } else if (codecContext->colorspace == AVCOL_SPC_BT709) { + return ColorConversionMatrices::BT709_LIMITED; + } else if (codecContext->colorspace == AVCOL_SPC_BT2020_NCL) { 
+ return ColorConversionMatrices::BT2020_LIMITED; + } else { // default to BT.601 + return ColorConversionMatrices::BT601_LIMITED; + } + } else if (codecContext->color_range == AVCOL_RANGE_JPEG) { // full range + if (codecContext->colorspace == AVCOL_SPC_BT470BG) { + return ColorConversionMatrices::BT601_FULL; + } else if (codecContext->colorspace == AVCOL_SPC_BT709) { + return ColorConversionMatrices::BT709_FULL; + } else if (codecContext->colorspace == AVCOL_SPC_BT2020_NCL) { + return ColorConversionMatrices::BT2020_FULL; + } else { // default to BT.601 + return ColorConversionMatrices::BT601_FULL; + } + } + return ColorConversionMatrices::BT601_LIMITED; +} +} // namespace + +UniqueAVFrame CudaDeviceInterface::convertTensorToAVFrameForEncoding( + const torch::stable::Tensor& tensor, + int frameIndex, + AVCodecContext* codecContext) { + STD_TORCH_CHECK( + tensor.dim() == 3 && tensor.sizes()[0] == 3, + "Expected 3D RGB tensor (CHW format), got ", + tensor.dim(), + "D tensor"); + STD_TORCH_CHECK( + tensor.device().type() == kStableCUDA, + "Expected tensor on CUDA device, got: ", + deviceTypeName(tensor.device().type())); + + UniqueAVFrame avFrame(av_frame_alloc()); + STD_TORCH_CHECK(avFrame != nullptr, "Failed to allocate AVFrame"); + int height = static_cast(tensor.sizes()[1]); + int width = static_cast(tensor.sizes()[2]); + + // TODO-VideoEncoder: (P1) Unify AVFrame creation with CPU method + avFrame->format = AV_PIX_FMT_CUDA; + avFrame->height = height; + avFrame->width = width; + avFrame->pts = frameIndex; + + // FFmpeg's av_hwframe_get_buffer is used to allocate memory on CUDA device. 
+ // TODO-VideoEncoder: (P2) Consider using pytorch to allocate CUDA memory for + // efficiency + int ret = + av_hwframe_get_buffer(codecContext->hw_frames_ctx, avFrame.get(), 0); + STD_TORCH_CHECK( + ret >= 0, + "Failed to allocate hardware frame: ", + getFFMPEGErrorStringFromErrorCode(ret)); + + STD_TORCH_CHECK( + avFrame != nullptr && avFrame->data[0] != nullptr, + "avFrame must be pre-allocated with CUDA memory"); + + // TODO VideoEncoder: Investigate ways to avoid this copy + torch::stable::Tensor hwcFrame = + torch::stable::contiguous(stablePermute(tensor, {1, 2, 0})); + + NppiSize oSizeROI = {width, height}; + NppStatus status; + // Convert to NV12, as CUDA_ENCODING_PIXEL_FORMAT is always NV12 currently + status = nppiRGBToNV12_8u_ColorTwist32f_C3P2R_Ctx( + hwcFrame.const_data_ptr(), + validateInt64ToInt( + hwcFrame.stride(0) * static_cast(hwcFrame.element_size()), + "nSrcStep"), + avFrame->data, + avFrame->linesize, + oSizeROI, + getConversionMatrix(codecContext), + *nppCtx_); + + STD_TORCH_CHECK( + status == NPP_SUCCESS, + "Failed to convert RGB to ", + av_get_pix_fmt_name(DeviceInterface::CUDA_ENCODING_PIXEL_FORMAT), + ": NPP error code ", + status); + + avFrame->colorspace = codecContext->colorspace; + avFrame->color_range = codecContext->color_range; + return avFrame; +} + +// Allocates and initializes AVHWFramesContext, and sets pixel format fields +// to enable encoding with CUDA device. The hw_frames_ctx field is needed by +// FFmpeg to allocate frames on GPU's memory. 
+void CudaDeviceInterface::setupHardwareFrameContextForEncoding( + AVCodecContext* codecContext) { + STD_TORCH_CHECK(codecContext != nullptr, "codecContext is null"); + STD_TORCH_CHECK( + hardwareDeviceCtx_, "Hardware device context has not been initialized"); + + AVBufferRef* hwFramesCtxRef = av_hwframe_ctx_alloc(hardwareDeviceCtx_.get()); + STD_TORCH_CHECK( + hwFramesCtxRef != nullptr, + "Failed to allocate hardware frames context for codec"); + + codecContext->sw_pix_fmt = DeviceInterface::CUDA_ENCODING_PIXEL_FORMAT; + // Always set pixel format to support CUDA encoding. + codecContext->pix_fmt = AV_PIX_FMT_CUDA; + + AVHWFramesContext* hwFramesCtx = + reinterpret_cast(hwFramesCtxRef->data); + hwFramesCtx->format = codecContext->pix_fmt; + hwFramesCtx->sw_format = codecContext->sw_pix_fmt; + hwFramesCtx->width = codecContext->width; + hwFramesCtx->height = codecContext->height; + + int ret = av_hwframe_ctx_init(hwFramesCtxRef); + if (ret < 0) { + av_buffer_unref(&hwFramesCtxRef); + STD_TORCH_CHECK( + false, + "Failed to initialize CUDA frames context for codec: ", + getFFMPEGErrorStringFromErrorCode(ret)); + } + codecContext->hw_frames_ctx = hwFramesCtxRef; +} } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/CudaDeviceInterface.h b/src/torchcodec/_core/CudaDeviceInterface.h index 9f171ee3c..d660559a1 100644 --- a/src/torchcodec/_core/CudaDeviceInterface.h +++ b/src/torchcodec/_core/CudaDeviceInterface.h @@ -6,19 +6,21 @@ #pragma once -#include "src/torchcodec/_core/CUDACommon.h" -#include "src/torchcodec/_core/DeviceInterface.h" -#include "src/torchcodec/_core/FilterGraph.h" +#include "CUDACommon.h" +#include "DeviceInterface.h" +#include "FilterGraph.h" namespace facebook::torchcodec { class CudaDeviceInterface : public DeviceInterface { public: - CudaDeviceInterface(const torch::Device& device); + CudaDeviceInterface(const StableDevice& device); virtual ~CudaDeviceInterface(); - std::optional findCodec(const AVCodecID& codecId) override; + 
std::optional findCodec( + const AVCodecID& codecId, + bool isDecoder = true) override; void initialize( const AVStream* avStream, @@ -37,11 +39,18 @@ class CudaDeviceInterface : public DeviceInterface { void convertAVFrameToFrameOutput( UniqueAVFrame& avFrame, FrameOutput& frameOutput, - std::optional preAllocatedOutputTensor = - std::nullopt) override; + std::optional preAllocatedOutputTensor) override; std::string getDetails() override; + UniqueAVFrame convertTensorToAVFrameForEncoding( + const torch::stable::Tensor& tensor, + int frameIndex, + AVCodecContext* codecContext) override; + + void setupHardwareFrameContextForEncoding( + AVCodecContext* codecContext) override; + private: // Our CUDA decoding code assumes NV12 format. In order to handle other // kinds of input, we need to convert them to NV12. Our current implementation @@ -60,10 +69,11 @@ class CudaDeviceInterface : public DeviceInterface { // This filtergraph instance is only used for NV12 format conversion in // maybeConvertAVFrameToNV12(). - std::unique_ptr nv12ConversionContext_; + std::unique_ptr nv12ConversionConfig_; std::unique_ptr nv12Conversion_; bool usingCPUFallback_ = false; + bool hasDecodedFrame_ = false; }; } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/DeviceInterface.cpp b/src/torchcodec/_core/DeviceInterface.cpp index 2f910e998..d26380180 100644 --- a/src/torchcodec/_core/DeviceInterface.cpp +++ b/src/torchcodec/_core/DeviceInterface.cpp @@ -4,9 +4,11 @@ // This source code is licensed under the BSD-style license found in the // LICENSE file in the root directory of this source tree. 
-#include "src/torchcodec/_core/DeviceInterface.h" +#include "DeviceInterface.h" +#include #include #include +#include "StableABICompat.h" namespace facebook::torchcodec { @@ -20,7 +22,7 @@ DeviceInterfaceMap& getDeviceMap() { return deviceMap; } -std::string getDeviceType(const std::string& device) { +std::string getDeviceTypeString(const std::string& device) { size_t pos = device.find(':'); if (pos == std::string::npos) { return device; @@ -28,6 +30,22 @@ std::string getDeviceType(const std::string& device) { return device.substr(0, pos); } +// Parse device type from string (e.g., "cpu", "cuda") +// TODO_STABLE_ABI: we might need to support more device types, i.e. those from +// https://github.com/pytorch/pytorch/blob/main/torch/headeronly/core/DeviceType.h +// Ideally we'd remove this helper? +StableDeviceType parseDeviceType(const std::string& deviceType) { + if (deviceType == "cpu") { + return kStableCPU; + } else if (deviceType == "cuda") { + return kStableCUDA; + } else if (deviceType == "xpu") { + return kStableXPU; + } else { + STD_TORCH_CHECK(false, "Unknown device type: ", deviceType); + } +} + } // namespace bool registerDeviceInterface( @@ -36,10 +54,10 @@ bool registerDeviceInterface( std::scoped_lock lock(g_interface_mutex); DeviceInterfaceMap& deviceMap = getDeviceMap(); - TORCH_CHECK( + STD_TORCH_CHECK( deviceMap.find(key) == deviceMap.end(), "Device interface already registered for device type ", - key.deviceType, + static_cast(key.deviceType), " variant '", key.variant, "'"); @@ -49,15 +67,15 @@ bool registerDeviceInterface( } void validateDeviceInterface( - const std::string device, - const std::string variant) { + const std::string& device, + const std::string& variant) { std::scoped_lock lock(g_interface_mutex); - std::string deviceType = getDeviceType(device); + std::string deviceType = getDeviceTypeString(device); DeviceInterfaceMap& deviceMap = getDeviceMap(); // Find device interface that matches device type and variant - torch::DeviceType 
deviceTypeEnum = torch::Device(deviceType).type(); + StableDeviceType deviceTypeEnum = parseDeviceType(deviceType); auto deviceInterface = std::find_if( deviceMap.begin(), @@ -67,7 +85,7 @@ void validateDeviceInterface( arg.first.variant == variant; }); - TORCH_CHECK( + STD_TORCH_CHECK( deviceInterface != deviceMap.end(), "Unsupported device: ", device, @@ -79,7 +97,7 @@ void validateDeviceInterface( } std::unique_ptr createDeviceInterface( - const torch::Device& device, + const StableDevice& device, const std::string_view variant) { DeviceInterfaceKey key(device.type(), variant); std::scoped_lock lock(g_interface_mutex); @@ -90,28 +108,40 @@ std::unique_ptr createDeviceInterface( return std::unique_ptr(it->second(device)); } - TORCH_CHECK( + STD_TORCH_CHECK( false, "No device interface found for device type: ", - device.type(), + static_cast(device.type()), " variant: '", variant, "'"); } -torch::Tensor rgbAVFrameToTensor(const UniqueAVFrame& avFrame) { - TORCH_CHECK_EQ(avFrame->format, AV_PIX_FMT_RGB24); +torch::stable::Tensor rgbAVFrameToTensor(const UniqueAVFrame& avFrame) { + STD_TORCH_CHECK(avFrame->format == AV_PIX_FMT_RGB24, "Expected RGB24 format"); int height = avFrame->height; int width = avFrame->width; std::vector shape = {height, width, 3}; std::vector strides = {avFrame->linesize[0], 3, 1}; AVFrame* avFrameClone = av_frame_clone(avFrame.get()); + + // TODO_STABLE_ABI: we're still using the non-stable ABI here. That's because + // stable::from_blob doesn't yet support a capturing lambda deleter. We need + // to land https://github.com/pytorch/pytorch/pull/175089. + // TC won't be able stable until this is resolved. 
auto deleter = [avFrameClone](void*) { UniqueAVFrame avFrameToDelete(avFrameClone); }; - return torch::from_blob( - avFrameClone->data[0], shape, strides, deleter, {torch::kUInt8}); + + at::Tensor tensor = at::from_blob( + avFrameClone->data[0], shape, strides, deleter, {at::kByte}); + + // We got an at::Tensor, we have to convert it to a torch::stable::Tensor. + // This is safe, there won't be any memory leak, i.e. the at::Tensor's deleter + // will properly be passed down to the torch::stable::Tensor. + at::Tensor* p = new at::Tensor(std::move(tensor)); + return torch::stable::Tensor(reinterpret_cast(p)); } } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/DeviceInterface.h b/src/torchcodec/_core/DeviceInterface.h index 773317e83..6b5388d26 100644 --- a/src/torchcodec/_core/DeviceInterface.h +++ b/src/torchcodec/_core/DeviceInterface.h @@ -6,21 +6,21 @@ #pragma once -#include #include #include #include #include #include "FFMPEGCommon.h" -#include "src/torchcodec/_core/Frame.h" -#include "src/torchcodec/_core/StreamOptions.h" -#include "src/torchcodec/_core/Transform.h" +#include "Frame.h" +#include "StableABICompat.h" +#include "StreamOptions.h" +#include "Transform.h" namespace facebook::torchcodec { // Key for device interface registration with device type + variant support struct DeviceInterfaceKey { - torch::DeviceType deviceType; + StableDeviceType deviceType; std::string_view variant = "ffmpeg"; // e.g., "ffmpeg", "beta", etc. 
bool operator<(const DeviceInterfaceKey& other) const { @@ -30,24 +30,25 @@ struct DeviceInterfaceKey { return variant < other.variant; } - explicit DeviceInterfaceKey(torch::DeviceType type) : deviceType(type) {} + explicit DeviceInterfaceKey(StableDeviceType type) : deviceType(type) {} - DeviceInterfaceKey(torch::DeviceType type, const std::string_view& variant) + DeviceInterfaceKey(StableDeviceType type, const std::string_view& variant) : deviceType(type), variant(variant) {} }; class DeviceInterface { public: - DeviceInterface(const torch::Device& device) : device_(device) {} + DeviceInterface(const StableDevice& device) : device_(device) {} virtual ~DeviceInterface(){}; - torch::Device& device() { + StableDevice& device() { return device_; }; virtual std::optional findCodec( - [[maybe_unused]] const AVCodecID& codecId) { + [[maybe_unused]] const AVCodecID& codecId, + [[maybe_unused]] bool isDecoder = true) { return std::nullopt; }; @@ -65,6 +66,21 @@ class DeviceInterface { transforms, [[maybe_unused]] const std::optional& resizedOutputDims) {} + // Initialize the device with parameters specific to audio decoding. There is + // a default empty implementation. + virtual void initializeAudio( + [[maybe_unused]] const AudioStreamOptions& audioStreamOptions) {} + + // Flush any remaining samples from the audio resampler buffer. + // When sample rate conversion is involved, some samples may be buffered + // between frames for proper interpolation. This function flushes those + // buffered samples. + // Returns an optional tensor containing the flushed samples, or std::nullopt + // if there are no buffered samples or audio is not supported. + virtual std::optional maybeFlushAudioBuffers() { + return std::nullopt; + } + // In order for decoding to actually happen on an FFmpeg managed hardware // device, we need to register the DeviceInterface managed // AVHardwareDeviceContext with the AVCodecContext. 
We don't need to do this @@ -75,7 +91,8 @@ class DeviceInterface { virtual void convertAVFrameToFrameOutput( UniqueAVFrame& avFrame, FrameOutput& frameOutput, - std::optional preAllocatedOutputTensor = std::nullopt) = 0; + std::optional preAllocatedOutputTensor = + std::nullopt) = 0; // ------------------------------------------ // Extension points for custom decoding paths @@ -85,7 +102,7 @@ class DeviceInterface { // other AVERROR on failure // Default implementation uses FFmpeg directly virtual int sendPacket(ReferenceAVPacket& avPacket) { - TORCH_CHECK( + STD_TORCH_CHECK( codecContext_ != nullptr, "Codec context not available for default packet sending"); return avcodec_send_packet(codecContext_.get(), avPacket.get()); @@ -95,7 +112,7 @@ class DeviceInterface { // Returns AVSUCCESS on success, or other AVERROR on failure // Default implementation uses FFmpeg directly virtual int sendEOFPacket() { - TORCH_CHECK( + STD_TORCH_CHECK( codecContext_ != nullptr, "Codec context not available for default EOF packet sending"); return avcodec_send_packet(codecContext_.get(), nullptr); @@ -105,7 +122,7 @@ class DeviceInterface { // AVERROR_EOF if end of stream, or other AVERROR on failure // Default implementation uses FFmpeg directly virtual int receiveFrame(UniqueAVFrame& avFrame) { - TORCH_CHECK( + STD_TORCH_CHECK( codecContext_ != nullptr, "Codec context not available for default frame receiving"); return avcodec_receive_frame(codecContext_.get(), avFrame.get()); @@ -113,7 +130,7 @@ class DeviceInterface { // Flush remaining frames from decoder virtual void flush() { - TORCH_CHECK( + STD_TORCH_CHECK( codecContext_ != nullptr, "Codec context not available for default flushing"); avcodec_flush_buffers(codecContext_.get()); @@ -123,26 +140,50 @@ class DeviceInterface { return ""; } + // Pixel format used for encoding on CUDA devices + static constexpr AVPixelFormat CUDA_ENCODING_PIXEL_FORMAT = AV_PIX_FMT_NV12; + + virtual UniqueAVFrame convertTensorToAVFrameForEncoding( + 
[[maybe_unused]] const torch::stable::Tensor& tensor, + [[maybe_unused]] int frameIndex, + [[maybe_unused]] AVCodecContext* codecContext) { + STD_TORCH_CHECK(false, "convertTensorToAVFrameForEncoding not implemented"); + } + + // Function used for video encoding, only implemented in CudaDeviceInterface. + // It is here to isolate CUDA dependencies from CPU builds + virtual void setupHardwareFrameContextForEncoding( + [[maybe_unused]] AVCodecContext* codecContext) { + STD_TORCH_CHECK( + false, "setupHardwareFrameContextForEncoding not implemented"); + } + + virtual std::optional findHardwareEncoder( + [[maybe_unused]] const AVCodecID& codecId) { + STD_TORCH_CHECK(false, "findHardwareEncoder not implemented"); + } + protected: - torch::Device device_; + StableDevice device_; SharedAVCodecContext codecContext_; + AVMediaType avMediaType_; }; using CreateDeviceInterfaceFn = - std::function; + std::function; -bool registerDeviceInterface( +TORCHCODEC_THIRD_PARTY_API bool registerDeviceInterface( const DeviceInterfaceKey& key, const CreateDeviceInterfaceFn createInterface); -void validateDeviceInterface( - const std::string device, - const std::string variant); +FORCE_PUBLIC_VISIBILITY void validateDeviceInterface( + const std::string& device, + const std::string& variant); std::unique_ptr createDeviceInterface( - const torch::Device& device, + const StableDevice& device, const std::string_view variant = "ffmpeg"); -torch::Tensor rgbAVFrameToTensor(const UniqueAVFrame& avFrame); +torch::stable::Tensor rgbAVFrameToTensor(const UniqueAVFrame& avFrame); } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/Encoder.cpp b/src/torchcodec/_core/Encoder.cpp index 4e5d6a604..7df2f48b0 100644 --- a/src/torchcodec/_core/Encoder.cpp +++ b/src/torchcodec/_core/Encoder.cpp @@ -1,19 +1,25 @@ #include -#include "src/torchcodec/_core/AVIOTensorContext.h" -#include "src/torchcodec/_core/Encoder.h" -#include "torch/types.h" +#include "AVIOTensorContext.h" +#include 
"Encoder.h" +#include "StableABICompat.h" + +extern "C" { +#include +#include +#include +} namespace facebook::torchcodec { namespace { -torch::Tensor validateSamples(const torch::Tensor& samples) { - TORCH_CHECK( - samples.dtype() == torch::kFloat32, +torch::stable::Tensor validateSamples(const torch::stable::Tensor& samples) { + STD_TORCH_CHECK( + samples.scalar_type() == kStableFloat32, "samples must have float32 dtype, got ", - samples.dtype()); - TORCH_CHECK( + (samples.scalar_type())); + STD_TORCH_CHECK( samples.dim() == 2, "samples must have 2 dimensions, got ", samples.dim()); @@ -21,7 +27,7 @@ torch::Tensor validateSamples(const torch::Tensor& samples) { // We enforce this, but if we get user reports we should investigate whether // that's actually needed. int numChannels = static_cast(samples.sizes()[0]); - TORCH_CHECK( + STD_TORCH_CHECK( numChannels <= AV_NUM_DATA_POINTERS, "Trying to encode ", numChannels, @@ -29,7 +35,7 @@ torch::Tensor validateSamples(const torch::Tensor& samples) { AV_NUM_DATA_POINTERS, " channels per frame."); - return samples.contiguous(); + return torch::stable::contiguous(samples); } void validateSampleRate(const AVCodec& avCodec, int sampleRate) { @@ -51,7 +57,7 @@ void validateSampleRate(const AVCodec& avCodec, int sampleRate) { supportedRates << supportedSampleRates[i]; } - TORCH_CHECK( + STD_TORCH_CHECK( false, "invalid sample rate=", sampleRate, @@ -100,30 +106,34 @@ AVSampleFormat findBestOutputSampleFormat(const AVCodec& avCodec) { return supportedSampleFormats[0]; } -} // namespace - -AudioEncoder::~AudioEncoder() { - close_avio(); -} +void closeAVIOContext( + AVFormatContext* avFormatContext, + AVIOContextHolder* avioContextHolder) { + if (!avFormatContext || !avFormatContext->pb) { + return; + } -void AudioEncoder::close_avio() { - if (avFormatContext_ && avFormatContext_->pb) { - if (avFormatContext_->pb->error == 0) { - avio_flush(avFormatContext_->pb); - } + if (avFormatContext->pb->error == 0) { + 
avio_flush(avFormatContext->pb); + } - if (!avioContextHolder_) { - if (avFormatContext_->pb->error == 0) { - avio_close(avFormatContext_->pb); - } - // avoids closing again in destructor, which would segfault. - avFormatContext_->pb = nullptr; + if (!avioContextHolder) { + if (avFormatContext->pb->error == 0) { + avio_close(avFormatContext->pb); } } + + avFormatContext->pb = nullptr; +} + +} // namespace + +AudioEncoder::~AudioEncoder() { + closeAVIOContext(avFormatContext_.get(), avioContextHolder_.get()); } AudioEncoder::AudioEncoder( - const torch::Tensor& samples, + const torch::stable::Tensor& samples, int sampleRate, std::string_view fileName, const AudioStreamOptions& audioStreamOptions) @@ -133,7 +143,7 @@ AudioEncoder::AudioEncoder( int status = avformat_alloc_output_context2( &avFormatContext, nullptr, nullptr, fileName.data()); - TORCH_CHECK( + STD_TORCH_CHECK( avFormatContext != nullptr, "Couldn't allocate AVFormatContext. ", "The destination file is ", @@ -143,7 +153,7 @@ AudioEncoder::AudioEncoder( avFormatContext_.reset(avFormatContext); status = avio_open(&avFormatContext_->pb, fileName.data(), AVIO_FLAG_WRITE); - TORCH_CHECK( + STD_TORCH_CHECK( status >= 0, "avio_open failed. The destination file is ", fileName, @@ -154,7 +164,7 @@ AudioEncoder::AudioEncoder( } AudioEncoder::AudioEncoder( - const torch::Tensor& samples, + const torch::stable::Tensor& samples, int sampleRate, std::string_view formatName, std::unique_ptr avioContextHolder, @@ -167,7 +177,7 @@ AudioEncoder::AudioEncoder( int status = avformat_alloc_output_context2( &avFormatContext, nullptr, formatName.data(), nullptr); - TORCH_CHECK( + STD_TORCH_CHECK( avFormatContext != nullptr, "Couldn't allocate AVFormatContext. ", "Check the desired format? Got format=", @@ -187,15 +197,16 @@ void AudioEncoder::initializeEncoder( // specific format/container. 
const AVCodec* avCodec = avcodec_find_encoder(avFormatContext_->oformat->audio_codec); - TORCH_CHECK(avCodec != nullptr, "Codec not found"); + STD_TORCH_CHECK(avCodec != nullptr, "Codec not found"); AVCodecContext* avCodecContext = avcodec_alloc_context3(avCodec); - TORCH_CHECK(avCodecContext != nullptr, "Couldn't allocate codec context."); + STD_TORCH_CHECK( + avCodecContext != nullptr, "Couldn't allocate codec context."); avCodecContext_.reset(avCodecContext); auto desiredBitRate = audioStreamOptions.bitRate; if (desiredBitRate.has_value()) { - TORCH_CHECK( + STD_TORCH_CHECK( *desiredBitRate >= 0, "bit_rate=", *desiredBitRate, " must be >= 0."); } // bit_rate=None defaults to 0, which is what the FFmpeg CLI seems to use as @@ -219,7 +230,7 @@ void AudioEncoder::initializeEncoder( avCodecContext_->sample_fmt = findBestOutputSampleFormat(*avCodec); int status = avcodec_open2(avCodecContext_.get(), avCodec, nullptr); - TORCH_CHECK( + STD_TORCH_CHECK( status == AVSUCCESS, "avcodec_open2 failed: ", getFFMPEGErrorStringFromErrorCode(status)); @@ -228,10 +239,10 @@ void AudioEncoder::initializeEncoder( // avformat_free_context(avFormatContext), which we call in the // avFormatContext_'s destructor. 
AVStream* avStream = avformat_new_stream(avFormatContext_.get(), nullptr); - TORCH_CHECK(avStream != nullptr, "Couldn't create new stream."); + STD_TORCH_CHECK(avStream != nullptr, "Couldn't create new stream."); status = avcodec_parameters_from_context( avStream->codecpar, avCodecContext_.get()); - TORCH_CHECK( + STD_TORCH_CHECK( status == AVSUCCESS, "avcodec_parameters_from_context failed: ", getFFMPEGErrorStringFromErrorCode(status)); @@ -248,19 +259,20 @@ void AudioEncoder::initializeEncoder( avCodecContext_->sample_fmt, outNumChannels_, avCodecContext_->frame_size * 2); - TORCH_CHECK(avAudioFifo != nullptr, "Couldn't create AVAudioFifo."); + STD_TORCH_CHECK(avAudioFifo != nullptr, "Couldn't create AVAudioFifo."); avAudioFifo_.reset(avAudioFifo); } } -torch::Tensor AudioEncoder::encodeToTensor() { - TORCH_CHECK( +torch::stable::Tensor AudioEncoder::encodeToTensor() { + STD_TORCH_CHECK( avioContextHolder_ != nullptr, "Cannot encode to tensor, avio tensor context doesn't exist."); encode(); auto avioToTensorContext = dynamic_cast(avioContextHolder_.get()); - TORCH_CHECK(avioToTensorContext != nullptr, "Invalid AVIO context holder."); + STD_TORCH_CHECK( + avioToTensorContext != nullptr, "Invalid AVIO context holder."); return avioToTensorContext->getOutputTensor(); } @@ -268,7 +280,7 @@ void AudioEncoder::encode() { // To be on the safe side we enforce that encode() can only be called once on // an encoder object. Whether this is actually necessary is unknown, so this // may be relaxed if needed. 
- TORCH_CHECK(!encodeWasCalled_, "Cannot call encode() twice."); + STD_TORCH_CHECK(!encodeWasCalled_, "Cannot call encode() twice."); encodeWasCalled_ = true; // Default to 256 like in torchaudio @@ -283,14 +295,15 @@ void AudioEncoder::encode() { AutoAVPacket autoAVPacket; - uint8_t* psamples = static_cast(samples_.data_ptr()); + const uint8_t* psamples = + static_cast(samples_.const_data_ptr()); int numSamples = static_cast(samples_.sizes()[1]); // per channel int numEncodedSamples = 0; // per channel int numBytesPerSample = static_cast(samples_.element_size()); int numBytesPerChannel = numSamples * numBytesPerSample; auto status = avformat_write_header(avFormatContext_.get(), nullptr); - TORCH_CHECK( + STD_TORCH_CHECK( status == AVSUCCESS, "Error in avformat_write_header: ", getFFMPEGErrorStringFromErrorCode(status)); @@ -320,17 +333,18 @@ void AudioEncoder::encode() { numEncodedSamples += numSamplesToEncode; } - TORCH_CHECK(numEncodedSamples == numSamples, "Hmmmmmm something went wrong."); + STD_TORCH_CHECK( + numEncodedSamples == numSamples, "Hmmmmmm something went wrong."); flushBuffers(); status = av_write_trailer(avFormatContext_.get()); - TORCH_CHECK( + STD_TORCH_CHECK( status == AVSUCCESS, "Error in: av_write_trailer", getFFMPEGErrorStringFromErrorCode(status)); - close_avio(); + closeAVIOContext(avFormatContext_.get(), avioContextHolder_.get()); } UniqueAVFrame AudioEncoder::maybeConvertAVFrame(const UniqueAVFrame& avFrame) { @@ -351,6 +365,12 @@ UniqueAVFrame AudioEncoder::maybeConvertAVFrame(const UniqueAVFrame& avFrame) { avFrame, outNumChannels_)); } + // convertAudioAVFrameSamples uses avFrame's extended_data field, so we ensure + // it's the same as data. This should always be the case since we validated + // earlier that we have less than AV_NUM_DATA_POINTERS channels. 
+ STD_TORCH_CHECK( + avFrame->data == avFrame->extended_data, + "Codec context data and extended_data pointers differ, this is unexpected."); UniqueAVFrame convertedAVFrame = convertAudioAVFrameSamples( swrContext_, avFrame, @@ -359,7 +379,7 @@ UniqueAVFrame AudioEncoder::maybeConvertAVFrame(const UniqueAVFrame& avFrame) { outNumChannels_); if (avFrame->sample_rate == outSampleRate_) { - TORCH_CHECK( + STD_TORCH_CHECK( convertedAVFrame->nb_samples == avFrame->nb_samples, "convertedAVFrame->nb_samples=", convertedAVFrame->nb_samples, @@ -386,7 +406,7 @@ void AudioEncoder::encodeFrameThroughFifo( avAudioFifo_.get(), reinterpret_cast(avFrame->data), avFrame->nb_samples); - TORCH_CHECK( + STD_TORCH_CHECK( numSamplesWritten == avFrame->nb_samples, "Tried to write ", avFrame->nb_samples, @@ -419,7 +439,7 @@ void AudioEncoder::encodeFrameThroughFifo( avAudioFifo_.get(), reinterpret_cast(newavFrame->data), samplesToRead); - TORCH_CHECK( + STD_TORCH_CHECK( numSamplesRead == samplesToRead, "Tried to read ", samplesToRead, @@ -440,7 +460,7 @@ void AudioEncoder::encodeFrame( } auto status = avcodec_send_frame(avCodecContext_.get(), avFrame.get()); - TORCH_CHECK( + STD_TORCH_CHECK( status == AVSUCCESS, "Error while sending frame: ", getFFMPEGErrorStringFromErrorCode(status)); @@ -455,14 +475,14 @@ void AudioEncoder::encodeFrame( // TorchAudio: // https://github.com/pytorch/audio/blob/d60ce09e2c532d5bf2e05619e700ab520543465e/src/libtorio/ffmpeg/stream_writer/encoder.cpp#L21 status = av_interleaved_write_frame(avFormatContext_.get(), nullptr); - TORCH_CHECK( + STD_TORCH_CHECK( status == AVSUCCESS, "Failed to flush packet: ", getFFMPEGErrorStringFromErrorCode(status)); } return; } - TORCH_CHECK( + STD_TORCH_CHECK( status >= 0, "Error receiving packet: ", getFFMPEGErrorStringFromErrorCode(status)); @@ -470,7 +490,7 @@ void AudioEncoder::encodeFrame( packet->stream_index = streamIndex_; status = av_interleaved_write_frame(avFormatContext_.get(), packet.get()); - TORCH_CHECK( + 
STD_TORCH_CHECK( status == AVSUCCESS, "Error in av_interleaved_write_frame: ", getFFMPEGErrorStringFromErrorCode(status)); @@ -484,7 +504,7 @@ void AudioEncoder::maybeFlushSwrBuffers(AutoAVPacket& autoAVPacket) { if (swrContext_ == nullptr && inSampleRate_ == outSampleRate_) { return; } - TORCH_CHECK( + STD_TORCH_CHECK( swrContext_ != nullptr, "swrContext is null, but sample rate conversion is needed. ", "This is unexpected, please report on the TorchCodec bug tracker."); @@ -518,42 +538,167 @@ void AudioEncoder::flushBuffers() { namespace { -torch::Tensor validateFrames(const torch::Tensor& frames) { - TORCH_CHECK( - frames.dtype() == torch::kUInt8, +torch::stable::Tensor validateFrames( + const torch::stable::Tensor& frames, + const AVCodecContext* avCodecContext = nullptr) { + STD_TORCH_CHECK( + frames.scalar_type() == kStableUInt8, "frames must have uint8 dtype, got ", - frames.dtype()); - TORCH_CHECK( + frames.scalar_type()); + STD_TORCH_CHECK( frames.dim() == 4, "frames must have 4 dimensions (N, C, H, W), got ", frames.dim()); - TORCH_CHECK( + STD_TORCH_CHECK( frames.sizes()[1] == 3, "frame must have 3 channels (R, G, B), got ", frames.sizes()[1]); - return frames.contiguous(); + if (avCodecContext) { + STD_TORCH_CHECK( + static_cast(frames.sizes()[2]) == avCodecContext->height && + static_cast(frames.sizes()[3]) == avCodecContext->width, + "All frames must have the same dimensions. 
Expected height=", + avCodecContext->height, + " width=", + avCodecContext->width, + ", got height=", + frames.sizes()[2], + " width=", + frames.sizes()[3]); + } + return torch::stable::contiguous(frames); } -} // namespace +AVPixelFormat validatePixelFormat( + const AVCodec& avCodec, + const std::string& targetPixelFormat) { + AVPixelFormat pixelFormat = av_get_pix_fmt(targetPixelFormat.c_str()); + + // Validate that the encoder supports this pixel format + const AVPixelFormat* supportedFormats = getSupportedPixelFormats(avCodec); + if (supportedFormats != nullptr) { + for (int i = 0; supportedFormats[i] != AV_PIX_FMT_NONE; ++i) { + if (supportedFormats[i] == pixelFormat) { + return pixelFormat; + } + } + } -VideoEncoder::~VideoEncoder() { - // TODO-VideoEncoder: Unify destructor with ~AudioEncoder() - if (avFormatContext_ && avFormatContext_->pb) { - if (avFormatContext_->pb->error == 0) { - avio_flush(avFormatContext_->pb); + std::stringstream errorMsg; + // av_get_pix_fmt failed to find a pix_fmt + if (pixelFormat == AV_PIX_FMT_NONE) { + errorMsg << "Unknown pixel format: " << targetPixelFormat; + } else { + errorMsg << "Specified pixel format " << targetPixelFormat + << " is not supported by the " << avCodec.name << " encoder."; + } + // Build error message, similar to FFmpeg's error log + errorMsg << "\nSupported pixel formats for " << avCodec.name << ":"; + for (int i = 0; supportedFormats[i] != AV_PIX_FMT_NONE; ++i) { + errorMsg << " " << av_get_pix_fmt_name(supportedFormats[i]); + } + STD_TORCH_CHECK(false, errorMsg.str()); +} + +void tryToValidateCodecOption( + const AVCodec& avCodec, + const char* optionName, + const std::string& value) { + if (!avCodec.priv_class) { + return; + } + const AVOption* option = av_opt_find2( + // Convert obj arg from const AVClass* const* to non-const void* + // First cast to remove const, then cast to void* + const_cast(static_cast(&avCodec.priv_class)), + optionName, + nullptr, + 0, + AV_OPT_SEARCH_FAKE_OBJ, + nullptr); + 
// If option is not found we cannot validate it, let FFmpeg handle it + if (!option) { + return; + } + // Validate if option is defined as a numeric type + if (option->type == AV_OPT_TYPE_INT || option->type == AV_OPT_TYPE_INT64 || + option->type == AV_OPT_TYPE_FLOAT || option->type == AV_OPT_TYPE_DOUBLE) { + try { + double numericValue = std::stod(value); + STD_TORCH_CHECK( + numericValue >= option->min && numericValue <= option->max, + optionName, + "=", + numericValue, + " is out of valid range [", + option->min, + ", ", + option->max, + "] for this codec. For more details, run 'ffmpeg -h encoder=", + avCodec.name, + "'"); + } catch (const std::invalid_argument&) { + STD_TORCH_CHECK( + false, + "Option ", + optionName, + " expects a numeric value but got '", + value, + "'"); } - if (!avioContextHolder_) { - if (avFormatContext_->pb->error == 0) { - avio_close(avFormatContext_->pb); - } - avFormatContext_->pb = nullptr; + } +} + +void sortCodecOptions( + const AVFormatContext* avFormatContext, + const std::map& extraOptions, + UniqueAVDictionary& codecDict, + UniqueAVDictionary& formatDict) { + // Accepts a map of options as input, then sorts them into codec options and + // format options. The sorted options are returned into two separate dicts. + const AVClass* formatClass = avformat_get_class(); + const AVClass* muxerClass = + avFormatContext->oformat ? 
avFormatContext->oformat->priv_class : nullptr; + for (const auto& [key, value] : extraOptions) { + // Check if option is generic format option + const AVOption* fmtOpt = av_opt_find2( + &formatClass, + key.c_str(), + nullptr, + 0, + AV_OPT_SEARCH_CHILDREN | AV_OPT_SEARCH_FAKE_OBJ, + nullptr); + // Check if option is muxer-specific option + // (Returned from `ffmpeg -h muxer=mp4`) + const AVOption* muxerOpt = nullptr; + if (muxerClass) { + muxerOpt = av_opt_find2( + &muxerClass, + key.c_str(), + nullptr, + 0, + AV_OPT_SEARCH_FAKE_OBJ, + nullptr); + } + if (fmtOpt || muxerOpt) { + // Pass container-format options to formatDict to be used in + // avformat_write_header + av_dict_set(formatDict.getAddress(), key.c_str(), value.c_str(), 0); + } else { + // By default, pass as codec option to be used in avcodec_open2 + av_dict_set(codecDict.getAddress(), key.c_str(), value.c_str(), 0); } } } +} // namespace + +VideoEncoder::~VideoEncoder() { + closeAVIOContext(avFormatContext_.get(), avioContextHolder_.get()); +} VideoEncoder::VideoEncoder( - const torch::Tensor& frames, - int frameRate, + const torch::stable::Tensor& frames, + double frameRate, std::string_view fileName, const VideoStreamOptions& videoStreamOptions) : frames_(validateFrames(frames)), inFrameRate_(frameRate) { @@ -564,7 +709,7 @@ VideoEncoder::VideoEncoder( int status = avformat_alloc_output_context2( &avFormatContext, nullptr, nullptr, fileName.data()); - TORCH_CHECK( + STD_TORCH_CHECK( avFormatContext != nullptr, "Couldn't allocate AVFormatContext. ", "The destination file is ", @@ -574,7 +719,7 @@ VideoEncoder::VideoEncoder( avFormatContext_.reset(avFormatContext); status = avio_open(&avFormatContext_->pb, fileName.data(), AVIO_FLAG_WRITE); - TORCH_CHECK( + STD_TORCH_CHECK( status >= 0, "avio_open failed. 
The destination file is ", fileName, @@ -584,8 +729,8 @@ VideoEncoder::VideoEncoder( } VideoEncoder::VideoEncoder( - const torch::Tensor& frames, - int frameRate, + const torch::stable::Tensor& frames, + double frameRate, std::string_view formatName, std::unique_ptr avioContextHolder, const VideoStreamOptions& videoStreamOptions) @@ -599,7 +744,7 @@ VideoEncoder::VideoEncoder( int status = avformat_alloc_output_context2( &avFormatContext, nullptr, formatName.data(), nullptr); - TORCH_CHECK( + STD_TORCH_CHECK( avFormatContext != nullptr, "Couldn't allocate AVFormatContext. ", "Check the desired format? Got format=", @@ -615,44 +760,100 @@ VideoEncoder::VideoEncoder( void VideoEncoder::initializeEncoder( const VideoStreamOptions& videoStreamOptions) { - const AVCodec* avCodec = - avcodec_find_encoder(avFormatContext_->oformat->video_codec); - TORCH_CHECK(avCodec != nullptr, "Video codec not found"); + auto tensorDevice = frames_.device(); + deviceInterface_ = createDeviceInterface(StableDevice( + static_cast(tensorDevice.type()), + tensorDevice.index())); + const AVCodec* avCodec = nullptr; + // If codec arg is provided, find codec using logic similar to FFmpeg: + // https://github.com/FFmpeg/FFmpeg/blob/master/fftools/ffmpeg_opt.c#L804-L835 + if (videoStreamOptions.codec.has_value()) { + const std::string& codec = videoStreamOptions.codec.value(); + // Try to find codec by name ("libx264", "libsvtav1") + avCodec = avcodec_find_encoder_by_name(codec.c_str()); + // Try to find by codec descriptor ("h264", "av1") + if (!avCodec) { + const AVCodecDescriptor* desc = + avcodec_descriptor_get_by_name(codec.c_str()); + if (desc) { + avCodec = avcodec_find_encoder(desc->id); + } + } + } else { + STD_TORCH_CHECK( + avFormatContext_->oformat != nullptr, + "Output format is null, unable to find default codec."); + // Try to substitute the default codec with its hardware equivalent + // This will return std::nullopt when device is CPU. 
+ auto hwCodec = deviceInterface_->findCodec( + avFormatContext_->oformat->video_codec, /*isDecoder=*/false); + if (hwCodec.has_value()) { + avCodec = hwCodec.value(); + } + if (!avCodec) { + avCodec = avcodec_find_encoder(avFormatContext_->oformat->video_codec); + } + } + STD_TORCH_CHECK( + avCodec != nullptr, + "Video codec ", + videoStreamOptions.codec.has_value() + ? videoStreamOptions.codec.value() + " " + : "", + "not found. To see available codecs, run: ffmpeg -encoders"); AVCodecContext* avCodecContext = avcodec_alloc_context3(avCodec); - TORCH_CHECK(avCodecContext != nullptr, "Couldn't allocate codec context."); + STD_TORCH_CHECK( + avCodecContext != nullptr, "Couldn't allocate codec context."); avCodecContext_.reset(avCodecContext); - // Store dimension order and input pixel format - // TODO-VideoEncoder: Remove assumption that tensor in NCHW format + // Store dimensions of input frames + // TODO-VideoEncoder: (P2) Enable tensors in NHWC shape auto sizes = frames_.sizes(); - inPixelFormat_ = AV_PIX_FMT_GBRP; - inHeight_ = static_cast(sizes[2]); - inWidth_ = static_cast(sizes[3]); - - // Use specified dimensions or input dimensions - // TODO-VideoEncoder: Allow height and width to be set - outWidth_ = inWidth_; - outHeight_ = inHeight_; - - // TODO-VideoEncoder: Enable other pixel formats - // Let FFmpeg choose best pixel format to minimize loss - outPixelFormat_ = avcodec_find_best_pix_fmt_of_list( - getSupportedPixelFormats(*avCodec), // List of supported formats - AV_PIX_FMT_GBRP, // We reorder input to GBRP currently - 0, // No alpha channel - nullptr // Discard conversion loss information - ); - TORCH_CHECK(outPixelFormat_ != -1, "Failed to find best pix fmt") + int inHeight = static_cast(sizes[2]); + int inWidth = static_cast(sizes[3]); + + // Always use input dimensions as output dimensions + // TODO-VideoEncoder: (P2) Allow height and width to be set + int outWidth = inWidth; + int outHeight = inHeight; + AVPixelFormat outPixelFormat = 
AV_PIX_FMT_NONE; + + if (videoStreamOptions.pixelFormat.has_value()) { + // TODO-VideoEncoder: (P2) Enable pixel formats to be set by user on GPU + // and handled with the appropriate NPP function on GPU. + if (frames_.device().type() == kStableCUDA) { + STD_TORCH_CHECK( + false, + "Video encoding on GPU currently only supports the nv12 pixel format. " + "Do not set pixel_format to use nv12 by default."); + } + outPixelFormat = + validatePixelFormat(*avCodec, videoStreamOptions.pixelFormat.value()); + } else { + if (frames_.device().type() == kStableCUDA) { + // Default to nv12 pixel format when encoding on GPU. + outPixelFormat = DeviceInterface::CUDA_ENCODING_PIXEL_FORMAT; + } else { + const AVPixelFormat* formats = getSupportedPixelFormats(*avCodec); + // Use first listed pixel format as default (often yuv420p). + // This is similar to FFmpeg's logic: + // https://www.ffmpeg.org/doxygen/4.0/decode_8c_source.html#l01087 + // If pixel formats are undefined for some reason, try yuv420p + outPixelFormat = (formats && formats[0] != AV_PIX_FMT_NONE) + ? 
formats[0] + : AV_PIX_FMT_YUV420P; + } + } // Configure codec parameters avCodecContext_->codec_id = avCodec->id; - avCodecContext_->width = outWidth_; - avCodecContext_->height = outHeight_; - avCodecContext_->pix_fmt = outPixelFormat_; - // TODO-VideoEncoder: Verify that frame_rate and time_base are correct - avCodecContext_->time_base = {1, inFrameRate_}; - avCodecContext_->framerate = {inFrameRate_, 1}; + avCodecContext_->width = outWidth; + avCodecContext_->height = outHeight; + avCodecContext_->pix_fmt = outPixelFormat; + // TODO-VideoEncoder: (P1) Add and utilize output frame_rate option + avCodecContext_->framerate = av_d2q(inFrameRate_, INT_MAX); + avCodecContext_->time_base = av_inv_q(avCodecContext_->framerate); // Set flag for containers that require extradata to be in the codec context if (avFormatContext_->oformat->flags & AVFMT_GLOBALHEADER) { @@ -660,30 +861,57 @@ void VideoEncoder::initializeEncoder( } // Apply videoStreamOptions - AVDictionary* options = nullptr; + UniqueAVDictionary avCodecOptions; + if (videoStreamOptions.extraOptions.has_value()) { + for (const auto& [key, value] : videoStreamOptions.extraOptions.value()) { + tryToValidateCodecOption(*avCodec, key.c_str(), value); + } + sortCodecOptions( + avFormatContext_.get(), + videoStreamOptions.extraOptions.value(), + avCodecOptions, + avFormatOptions_); + } + if (videoStreamOptions.crf.has_value()) { + std::string crfValue = std::to_string(videoStreamOptions.crf.value()); + tryToValidateCodecOption(*avCodec, "crf", crfValue); + av_dict_set(avCodecOptions.getAddress(), "crf", crfValue.c_str(), 0); + } + if (videoStreamOptions.preset.has_value()) { av_dict_set( - &options, - "crf", - std::to_string(videoStreamOptions.crf.value()).c_str(), + avCodecOptions.getAddress(), + "preset", + videoStreamOptions.preset.value().c_str(), 0); } - int status = avcodec_open2(avCodecContext_.get(), avCodec, &options); - av_dict_free(&options); - TORCH_CHECK( + if (frames_.device().type() == kStableCUDA) { 
+ deviceInterface_->registerHardwareDeviceWithCodec(avCodecContext_.get()); + deviceInterface_->setupHardwareFrameContextForEncoding( + avCodecContext_.get()); + } + + int status = avcodec_open2( + avCodecContext_.get(), avCodec, avCodecOptions.getAddress()); + + STD_TORCH_CHECK( status == AVSUCCESS, "avcodec_open2 failed: ", getFFMPEGErrorStringFromErrorCode(status)); avStream_ = avformat_new_stream(avFormatContext_.get(), nullptr); - TORCH_CHECK(avStream_ != nullptr, "Couldn't create new stream."); + STD_TORCH_CHECK(avStream_ != nullptr, "Couldn't create new stream."); // Set the stream time base to encode correct frame timestamps avStream_->time_base = avCodecContext_->time_base; + // Set the stream frame rate to store correct frame durations for some + // containers (webm, mkv) + avStream_->r_frame_rate = avCodecContext_->framerate; + status = avcodec_parameters_from_context( avStream_->codecpar, avCodecContext_.get()); - TORCH_CHECK( + STD_TORCH_CHECK( status == AVSUCCESS, "avcodec_parameters_from_context failed: ", getFFMPEGErrorStringFromErrorCode(status)); @@ -691,11 +919,12 @@ void VideoEncoder::initializeEncoder( void VideoEncoder::encode() { // To be on the safe side we enforce that encode() can only be called once - TORCH_CHECK(!encodeWasCalled_, "Cannot call encode() twice."); + STD_TORCH_CHECK(!encodeWasCalled_, "Cannot call encode() twice."); encodeWasCalled_ = true; - int status = avformat_write_header(avFormatContext_.get(), nullptr); - TORCH_CHECK( + int status = avformat_write_header( + avFormatContext_.get(), avFormatOptions_.getAddress()); + STD_TORCH_CHECK( status == AVSUCCESS, "Error in avformat_write_header: ", getFFMPEGErrorStringFromErrorCode(status)); @@ -703,102 +932,328 @@ void VideoEncoder::encode() { AutoAVPacket autoAVPacket; int numFrames = static_cast(frames_.sizes()[0]); for (int i = 0; i < numFrames; ++i) { - torch::Tensor currFrame = frames_[i]; - UniqueAVFrame avFrame = convertTensorToAVFrame(currFrame, i); + 
torch::stable::Tensor currFrame = selectRow(frames_, i); + UniqueAVFrame avFrame = deviceInterface_->convertTensorToAVFrameForEncoding( + currFrame, i, avCodecContext_.get()); + STD_TORCH_CHECK( + avFrame != nullptr, + "convertTensorToAVFrameForEncoding failed for frame ", + i, + " on device: ", + deviceTypeName(frames_.device().type())); encodeFrame(autoAVPacket, avFrame); } flushBuffers(); status = av_write_trailer(avFormatContext_.get()); - TORCH_CHECK( + // av_write_trailer returns mfra atom size (positive) for fragmented + // containers, which we'd misinterpret as an error, since all FFmpeg errors + // are negative (see AVERROR definition: + // http://ffmpeg.org/doxygen/8.0/error_8h_source.html) So we replace positive + // values with AVSUCCESS. See: + // https://github.com/FFmpeg/FFmpeg/blob/n8.0/libavformat/movenc.c#L8666 + if (status > 0) { + status = AVSUCCESS; + } + STD_TORCH_CHECK( status == AVSUCCESS, "Error in av_write_trailer: ", getFFMPEGErrorStringFromErrorCode(status)); } -UniqueAVFrame VideoEncoder::convertTensorToAVFrame( - const torch::Tensor& frame, - int frameIndex) { - // Initialize and cache scaling context if it does not exist - if (!swsContext_) { - swsContext_.reset(sws_getContext( - inWidth_, - inHeight_, - inPixelFormat_, - outWidth_, - outHeight_, - outPixelFormat_, - SWS_BICUBIC, // Used by FFmpeg CLI - nullptr, - nullptr, - nullptr)); - TORCH_CHECK(swsContext_ != nullptr, "Failed to create scaling context"); +torch::stable::Tensor VideoEncoder::encodeToTensor() { + STD_TORCH_CHECK( + avioContextHolder_ != nullptr, + "Cannot encode to tensor, avio tensor context doesn't exist."); + encode(); + auto avioToTensorContext = + dynamic_cast(avioContextHolder_.get()); + STD_TORCH_CHECK( + avioToTensorContext != nullptr, "Invalid AVIO context holder."); + return avioToTensorContext->getOutputTensor(); +} + +void VideoEncoder::encodeFrame( + AutoAVPacket& autoAVPacket, + const UniqueAVFrame& avFrame) { + auto status = 
avcodec_send_frame(avCodecContext_.get(), avFrame.get()); + STD_TORCH_CHECK( + status == AVSUCCESS, + "Error while sending frame: ", + getFFMPEGErrorStringFromErrorCode(status)); + + while (status >= 0) { + ReferenceAVPacket packet(autoAVPacket); + status = avcodec_receive_packet(avCodecContext_.get(), packet.get()); + if (status == AVERROR(EAGAIN) || status == AVERROR_EOF) { + if (status == AVERROR_EOF) { + // Flush remaining buffered packets + status = av_interleaved_write_frame(avFormatContext_.get(), nullptr); + STD_TORCH_CHECK( + status == AVSUCCESS, + "Failed to flush packet: ", + getFFMPEGErrorStringFromErrorCode(status)); + } + return; + } + STD_TORCH_CHECK( + status >= 0, + "Error receiving packet: ", + getFFMPEGErrorStringFromErrorCode(status)); + + // The code below is borrowed from torchaudio: + // https://github.com/pytorch/audio/blob/b6a3368a45aaafe05f1a6a9f10c68adc5e944d9e/src/libtorio/ffmpeg/stream_writer/encoder.cpp#L46 + // Setting packet->duration to 1 allows the last frame to be properly + // encoded, and needs to be set before calling av_packet_rescale_ts. 
+ if (packet->duration == 0) { + packet->duration = 1; + } + av_packet_rescale_ts( + packet.get(), avCodecContext_->time_base, avStream_->time_base); + packet->stream_index = avStream_->index; + + status = av_interleaved_write_frame(avFormatContext_.get(), packet.get()); + STD_TORCH_CHECK( + status == AVSUCCESS, + "Error in av_interleaved_write_frame: ", + getFFMPEGErrorStringFromErrorCode(status)); } +} + +void VideoEncoder::flushBuffers() { + AutoAVPacket autoAVPacket; + // Send null frame to signal end of input + encodeFrame(autoAVPacket, UniqueAVFrame(nullptr)); +} + +MultiStreamEncoder::~MultiStreamEncoder() { + close(); +} + +MultiStreamEncoder::MultiStreamEncoder(std::string_view fileName) { + setFFmpegLogLevel(); + + AVFormatContext* avFormatContext = nullptr; + int status = avformat_alloc_output_context2( + &avFormatContext, nullptr, nullptr, fileName.data()); + + STD_TORCH_CHECK( + avFormatContext != nullptr, + "Couldn't allocate AVFormatContext. ", + "The destination file is ", + fileName, + ", check the desired extension? ", + getFFMPEGErrorStringFromErrorCode(status)); + avFormatContext_.reset(avFormatContext); - UniqueAVFrame avFrame(av_frame_alloc()); - TORCH_CHECK(avFrame != nullptr, "Failed to allocate AVFrame"); + status = avio_open(&avFormatContext_->pb, fileName.data(), AVIO_FLAG_WRITE); + STD_TORCH_CHECK( + status >= 0, + "avio_open failed. The destination file is ", + fileName, + ", make sure it's a valid path? ", + getFFMPEGErrorStringFromErrorCode(status)); +} - // Set output frame properties - avFrame->format = outPixelFormat_; - avFrame->width = outWidth_; - avFrame->height = outHeight_; - avFrame->pts = frameIndex; +void MultiStreamEncoder::addVideoStream( + double frameRate, + std::optional codec, + std::optional pixelFormat, + std::optional crf, + std::optional preset, + std::optional> extraOptions) { + STD_TORCH_CHECK( + inFrameRate_ == 0, + "A video stream has already been added. 
Cannot add another."); + STD_TORCH_CHECK(frameRate > 0, "frame_rate must be > 0, got ", frameRate); + inFrameRate_ = frameRate; + videoStreamOptions_.codec = std::move(codec); + videoStreamOptions_.pixelFormat = std::move(pixelFormat); + videoStreamOptions_.crf = crf; + videoStreamOptions_.preset = std::move(preset); + videoStreamOptions_.extraOptions = std::move(extraOptions); +} - int status = av_frame_get_buffer(avFrame.get(), 0); - TORCH_CHECK(status >= 0, "Failed to allocate frame buffer"); +void MultiStreamEncoder::initializeVideoStream( + const torch::stable::Tensor& frames) { + auto tensorDevice = frames.device(); + // TODO MultiStreamEncoder: Enable CUDA support + STD_TORCH_CHECK( + tensorDevice.is_cpu(), "Only CPU tensors are supported for encoding."); + deviceInterface_ = createDeviceInterface(StableDevice( + static_cast(tensorDevice.type()), + tensorDevice.index())); + const AVCodec* avCodec = nullptr; + // If codec arg is provided, find codec using logic similar to FFmpeg: + // https://github.com/FFmpeg/FFmpeg/blob/master/fftools/ffmpeg_opt.c#L804-L835 + if (videoStreamOptions_.codec.has_value()) { + const std::string& codec = videoStreamOptions_.codec.value(); + // Try to find codec by name ("libx264", "libsvtav1") + avCodec = avcodec_find_encoder_by_name(codec.c_str()); + // Try to find by codec descriptor ("h264", "av1") + if (!avCodec) { + const AVCodecDescriptor* desc = + avcodec_descriptor_get_by_name(codec.c_str()); + if (desc) { + avCodec = avcodec_find_encoder(desc->id); + } + } + } else { + STD_TORCH_CHECK( + avFormatContext_->oformat != nullptr, + "Output format is null, unable to find default codec."); + // TODO MultiStreamEncoder: When CUDA support is enabled, substitute codec + // with hardware equivalent + avCodec = avcodec_find_encoder(avFormatContext_->oformat->video_codec); + } + STD_TORCH_CHECK( + avCodec != nullptr, + "Video codec ", + videoStreamOptions_.codec.has_value() + ? 
videoStreamOptions_.codec.value() + " " + : "", + "not found. To see available codecs, run: ffmpeg -encoders"); - // Need to convert/scale the frame - // Create temporary frame with input format - UniqueAVFrame inputFrame(av_frame_alloc()); - TORCH_CHECK(inputFrame != nullptr, "Failed to allocate input AVFrame"); + AVCodecContext* avCodecContext = avcodec_alloc_context3(avCodec); + STD_TORCH_CHECK( + avCodecContext != nullptr, "Couldn't allocate codec context."); + avCodecContext_.reset(avCodecContext); - inputFrame->format = inPixelFormat_; - inputFrame->width = inWidth_; - inputFrame->height = inHeight_; + // Store dimensions of input frames + // TODO MultiStreamEncoder: Enable tensors in NHWC shape + auto sizes = frames.sizes(); + int inHeight = static_cast(sizes[2]); + int inWidth = static_cast(sizes[3]); + + // Always use input dimensions as output dimensions + // TODO MultiStreamEncoder: Allow height and width to be set + int outWidth = inWidth; + int outHeight = inHeight; + AVPixelFormat outPixelFormat = AV_PIX_FMT_NONE; + if (videoStreamOptions_.pixelFormat.has_value()) { + outPixelFormat = + validatePixelFormat(*avCodec, videoStreamOptions_.pixelFormat.value()); + } else { + const AVPixelFormat* formats = getSupportedPixelFormats(*avCodec); + // Use first listed pixel format as default (often yuv420p). + // This is similar to FFmpeg's logic: + // https://www.ffmpeg.org/doxygen/4.0/decode_8c_source.html#l01087 + // If pixel formats are undefined for some reason, try yuv420p + outPixelFormat = (formats && formats[0] != AV_PIX_FMT_NONE) + ? 
formats[0] + : AV_PIX_FMT_YUV420P; + } - uint8_t* tensorData = static_cast(frame.data_ptr()); + // Configure codec parameters + avCodecContext_->codec_id = avCodec->id; + avCodecContext_->width = outWidth; + avCodecContext_->height = outHeight; + avCodecContext_->pix_fmt = outPixelFormat; + // TODO MultiStreamEncoder: Add and utilize output frame_rate option + avCodecContext_->framerate = av_d2q(inFrameRate_, INT_MAX); + avCodecContext_->time_base = av_inv_q(avCodecContext_->framerate); - // TODO-VideoEncoder: Reorder tensor if in NHWC format - int channelSize = inHeight_ * inWidth_; - // Reorder RGB -> GBR for AV_PIX_FMT_GBRP format - // TODO-VideoEncoder: Determine if FFmpeg supports planar RGB input format - inputFrame->data[0] = tensorData + channelSize; - inputFrame->data[1] = tensorData + (2 * channelSize); - inputFrame->data[2] = tensorData; + // Set flag for containers that require extradata to be in the codec context + if (avFormatContext_->oformat->flags & AVFMT_GLOBALHEADER) { + avCodecContext_->flags |= AV_CODEC_FLAG_GLOBAL_HEADER; + } - inputFrame->linesize[0] = inWidth_; - inputFrame->linesize[1] = inWidth_; - inputFrame->linesize[2] = inWidth_; + // Apply videoStreamOptions + UniqueAVDictionary avCodecOptions; + if (videoStreamOptions_.extraOptions.has_value()) { + for (const auto& [key, value] : videoStreamOptions_.extraOptions.value()) { + tryToValidateCodecOption(*avCodec, key.c_str(), value); + } + sortCodecOptions( + avFormatContext_.get(), + videoStreamOptions_.extraOptions.value(), + avCodecOptions, + avFormatOptions_); + } - status = sws_scale( - swsContext_.get(), - inputFrame->data, - inputFrame->linesize, - 0, - inputFrame->height, - avFrame->data, - avFrame->linesize); - TORCH_CHECK(status == outHeight_, "sws_scale failed"); - return avFrame; + if (videoStreamOptions_.crf.has_value()) { + std::string crfValue = std::to_string(videoStreamOptions_.crf.value()); + tryToValidateCodecOption(*avCodec, "crf", crfValue); + 
av_dict_set(avCodecOptions.getAddress(), "crf", crfValue.c_str(), 0); + } + + if (videoStreamOptions_.preset.has_value()) { + av_dict_set( + avCodecOptions.getAddress(), + "preset", + videoStreamOptions_.preset.value().c_str(), + 0); + } + + int status = avcodec_open2( + avCodecContext_.get(), avCodec, avCodecOptions.getAddress()); + + STD_TORCH_CHECK( + status == AVSUCCESS, + "avcodec_open2 failed: ", + getFFMPEGErrorStringFromErrorCode(status)); + + avStream_ = avformat_new_stream(avFormatContext_.get(), nullptr); + STD_TORCH_CHECK(avStream_ != nullptr, "Couldn't create new stream."); + + // Set the stream time base to encode correct frame timestamps + avStream_->time_base = avCodecContext_->time_base; + // Set the stream frame rate to store correct frame durations for some + // containers (webm, mkv) + avStream_->r_frame_rate = avCodecContext_->framerate; + + status = avcodec_parameters_from_context( + avStream_->codecpar, avCodecContext_.get()); + STD_TORCH_CHECK( + status == AVSUCCESS, + "avcodec_parameters_from_context failed: ", + getFFMPEGErrorStringFromErrorCode(status)); + + status = avformat_write_header( + avFormatContext_.get(), avFormatOptions_.getAddress()); + STD_TORCH_CHECK( + status == AVSUCCESS, + "Error in avformat_write_header: ", + getFFMPEGErrorStringFromErrorCode(status)); + headerWritten_ = true; } -torch::Tensor VideoEncoder::encodeToTensor() { - TORCH_CHECK( - avioContextHolder_ != nullptr, - "Cannot encode to tensor, avio tensor context doesn't exist."); - encode(); - auto avioToTensorContext = - dynamic_cast(avioContextHolder_.get()); - TORCH_CHECK(avioToTensorContext != nullptr, "Invalid AVIO context holder."); - return avioToTensorContext->getOutputTensor(); +void MultiStreamEncoder::addFrames(const torch::stable::Tensor& frames) { + STD_TORCH_CHECK( + inFrameRate_ > 0, + "No video stream has been added. 
Call addVideoStream() first."); + auto validatedFrames = validateFrames(frames, avCodecContext_.get()); + + if (!headerWritten_) { + initializeVideoStream(validatedFrames); + } + + AutoAVPacket autoAVPacket; + // TODO MultiStreamEncoder: Consider using accessor for potential performance + // improvement + int numFrames = static_cast(validatedFrames.sizes()[0]); + for (int i = 0; i < numFrames; ++i) { + torch::stable::Tensor currFrame = selectRow(validatedFrames, i); + int frameIndex = numEncodedFrames_ + i; + UniqueAVFrame avFrame = deviceInterface_->convertTensorToAVFrameForEncoding( + currFrame, frameIndex, avCodecContext_.get()); + STD_TORCH_CHECK( + avFrame != nullptr, + "convertTensorToAVFrameForEncoding failed for frame ", + frameIndex, + " on device: ", + deviceTypeName(validatedFrames.device().type())); + encodeFrame(autoAVPacket, avFrame); + } + numEncodedFrames_ += numFrames; } -void VideoEncoder::encodeFrame( +void MultiStreamEncoder::encodeFrame( AutoAVPacket& autoAVPacket, const UniqueAVFrame& avFrame) { auto status = avcodec_send_frame(avCodecContext_.get(), avFrame.get()); - TORCH_CHECK( + STD_TORCH_CHECK( status == AVSUCCESS, "Error while sending frame: ", getFFMPEGErrorStringFromErrorCode(status)); @@ -810,14 +1265,14 @@ void VideoEncoder::encodeFrame( if (status == AVERROR_EOF) { // Flush remaining buffered packets status = av_interleaved_write_frame(avFormatContext_.get(), nullptr); - TORCH_CHECK( + STD_TORCH_CHECK( status == AVSUCCESS, "Failed to flush packet: ", getFFMPEGErrorStringFromErrorCode(status)); } return; } - TORCH_CHECK( + STD_TORCH_CHECK( status >= 0, "Error receiving packet: ", getFFMPEGErrorStringFromErrorCode(status)); @@ -834,17 +1289,42 @@ void VideoEncoder::encodeFrame( packet->stream_index = avStream_->index; status = av_interleaved_write_frame(avFormatContext_.get(), packet.get()); - TORCH_CHECK( + STD_TORCH_CHECK( status == AVSUCCESS, "Error in av_interleaved_write_frame: ", getFFMPEGErrorStringFromErrorCode(status)); } } 
-void VideoEncoder::flushBuffers() { +void MultiStreamEncoder::flushBuffers() { AutoAVPacket autoAVPacket; // Send null frame to signal end of input encodeFrame(autoAVPacket, UniqueAVFrame(nullptr)); } +void MultiStreamEncoder::close() { + if (closed_) { + return; + } + // TODO MultiStreamEncoder: Revisit if "closed_" flag is useful + closed_ = true; + + if (headerWritten_) { + flushBuffers(); + + int status = av_write_trailer(avFormatContext_.get()); + // av_write_trailer returns mfra atom size (positive) for fragmented + // containers. All FFmpeg errors are negative, so positive is not an error. + if (status > 0) { + status = AVSUCCESS; + } + STD_TORCH_CHECK( + status == AVSUCCESS, + "Error in av_write_trailer: ", + getFFMPEGErrorStringFromErrorCode(status)); + } + + closeAVIOContext(avFormatContext_.get(), avioContextHolder_.get()); +} + } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/Encoder.h b/src/torchcodec/_core/Encoder.h index 168591616..165e8d8d2 100644 --- a/src/torchcodec/_core/Encoder.h +++ b/src/torchcodec/_core/Encoder.h @@ -1,22 +1,29 @@ #pragma once -#include -#include "src/torchcodec/_core/AVIOContextHolder.h" -#include "src/torchcodec/_core/FFMPEGCommon.h" -#include "src/torchcodec/_core/StreamOptions.h" +#include +#include +#include "AVIOContextHolder.h" +#include "DeviceInterface.h" +#include "FFMPEGCommon.h" +#include "StableABICompat.h" +#include "StreamOptions.h" + +extern "C" { +#include +} namespace facebook::torchcodec { -class AudioEncoder { +class FORCE_PUBLIC_VISIBILITY AudioEncoder { public: ~AudioEncoder(); AudioEncoder( - const torch::Tensor& samples, + const torch::stable::Tensor& samples, int sampleRate, std::string_view fileName, const AudioStreamOptions& audioStreamOptions); AudioEncoder( - const torch::Tensor& samples, + const torch::stable::Tensor& samples, int sampleRate, std::string_view formatName, std::unique_ptr avioContextHolder, @@ -24,7 +31,7 @@ class AudioEncoder { void encode(); - torch::Tensor 
encodeToTensor(); + torch::stable::Tensor encodeToTensor(); private: void initializeEncoder(const AudioStreamOptions& audioStreamOptions); @@ -36,7 +43,6 @@ class AudioEncoder { void encodeFrame(AutoAVPacket& autoAVPacket, const UniqueAVFrame& avFrame); void maybeFlushSwrBuffers(AutoAVPacket& autoAVPacket); void flushBuffers(); - void close_avio(); UniqueEncodingAVFormatContext avFormatContext_; UniqueAVCodecContext avCodecContext_; @@ -44,7 +50,7 @@ class AudioEncoder { UniqueSwrContext swrContext_; AudioStreamOptions audioStreamOptions; - const torch::Tensor samples_; + const torch::stable::Tensor samples_; int outNumChannels_ = -1; int outSampleRate_ = -1; @@ -121,7 +127,7 @@ class AudioEncoder { // /* clang-format on */ -class VideoEncoder { +class FORCE_PUBLIC_VISIBILITY VideoEncoder { public: ~VideoEncoder(); @@ -131,54 +137,85 @@ class VideoEncoder { VideoEncoder(const VideoEncoder&) = delete; VideoEncoder& operator=(const VideoEncoder&) = delete; - // Move assignment operator deleted since we have a const member - VideoEncoder(VideoEncoder&&) = default; + // Move operators deleted since UniqueAVDictionary member is not movable + VideoEncoder(VideoEncoder&&) = delete; VideoEncoder& operator=(VideoEncoder&&) = delete; VideoEncoder( - const torch::Tensor& frames, - int frameRate, + const torch::stable::Tensor& frames, + double frameRate, std::string_view fileName, const VideoStreamOptions& videoStreamOptions); VideoEncoder( - const torch::Tensor& frames, - int frameRate, + const torch::stable::Tensor& frames, + double frameRate, std::string_view formatName, std::unique_ptr avioContextHolder, const VideoStreamOptions& videoStreamOptions); void encode(); - torch::Tensor encodeToTensor(); + torch::stable::Tensor encodeToTensor(); private: void initializeEncoder(const VideoStreamOptions& videoStreamOptions); - UniqueAVFrame convertTensorToAVFrame( - const torch::Tensor& frame, - int frameIndex); void encodeFrame(AutoAVPacket& autoAVPacket, const UniqueAVFrame& 
avFrame); void flushBuffers(); UniqueEncodingAVFormatContext avFormatContext_; UniqueAVCodecContext avCodecContext_; AVStream* avStream_ = nullptr; - UniqueSwsContext swsContext_; - const torch::Tensor frames_; - int inFrameRate_; - - int inWidth_ = -1; - int inHeight_ = -1; - AVPixelFormat inPixelFormat_ = AV_PIX_FMT_NONE; - - int outWidth_ = -1; - int outHeight_ = -1; - AVPixelFormat outPixelFormat_ = AV_PIX_FMT_NONE; + const torch::stable::Tensor frames_; + double inFrameRate_; std::unique_ptr avioContextHolder_; + std::unique_ptr deviceInterface_; bool encodeWasCalled_ = false; + UniqueAVDictionary avFormatOptions_; +}; + +class FORCE_PUBLIC_VISIBILITY MultiStreamEncoder { + public: + ~MultiStreamEncoder(); + + MultiStreamEncoder(const MultiStreamEncoder&) = delete; + MultiStreamEncoder& operator=(const MultiStreamEncoder&) = delete; + MultiStreamEncoder(MultiStreamEncoder&&) = delete; + MultiStreamEncoder& operator=(MultiStreamEncoder&&) = delete; + + MultiStreamEncoder(std::string_view fileName); + + void addVideoStream( + double frameRate, + std::optional codec = std::nullopt, + std::optional pixelFormat = std::nullopt, + std::optional crf = std::nullopt, + std::optional preset = std::nullopt, + std::optional> extraOptions = + std::nullopt); + void addFrames(const torch::stable::Tensor& frames); + void close(); + + private: + void initializeVideoStream(const torch::stable::Tensor& frames); + void encodeFrame(AutoAVPacket& autoAVPacket, const UniqueAVFrame& avFrame); + void flushBuffers(); + + UniqueEncodingAVFormatContext avFormatContext_; + UniqueAVCodecContext avCodecContext_; + AVStream* avStream_ = nullptr; + double inFrameRate_ = 0; + VideoStreamOptions videoStreamOptions_; + std::unique_ptr deviceInterface_; + bool headerWritten_ = false; + int numEncodedFrames_ = 0; + UniqueAVDictionary avFormatOptions_; + + std::unique_ptr avioContextHolder_; + bool closed_ = false; }; } // namespace facebook::torchcodec diff --git 
a/src/torchcodec/_core/FFMPEGCommon.cpp b/src/torchcodec/_core/FFMPEGCommon.cpp index b9663d8d2..e7b3efa35 100644 --- a/src/torchcodec/_core/FFMPEGCommon.cpp +++ b/src/torchcodec/_core/FFMPEGCommon.cpp @@ -4,9 +4,9 @@ // This source code is licensed under the BSD-style license found in the // LICENSE file in the root directory of this source tree. -#include "src/torchcodec/_core/FFMPEGCommon.h" +#include "FFMPEGCommon.h" -#include +#include "StableABICompat.h" extern "C" { #include @@ -16,7 +16,7 @@ extern "C" { namespace facebook::torchcodec { AutoAVPacket::AutoAVPacket() : avPacket_(av_packet_alloc()) { - TORCH_CHECK(avPacket_ != nullptr, "Couldn't allocate avPacket."); + STD_TORCH_CHECK(avPacket_ != nullptr, "Couldn't allocate avPacket."); } AutoAVPacket::~AutoAVPacket() { @@ -102,7 +102,8 @@ const AVPixelFormat* getSupportedPixelFormats(const AVCodec& avCodec) { reinterpret_cast(&supportedPixelFormats), &numPixelFormats); if (ret < 0 || supportedPixelFormats == nullptr) { - TORCH_CHECK(false, "Couldn't get supported pixel formats from encoder."); + STD_TORCH_CHECK( + false, "Couldn't get supported pixel formats from encoder."); } #else supportedPixelFormats = avCodec.pix_fmts; @@ -158,6 +159,16 @@ int getNumChannels(const SharedAVCodecContext& avCodecContext) { #endif } +int getNumChannels(const AVCodecParameters* codecpar) { + STD_TORCH_CHECK(codecpar != nullptr, "codecpar is null"); +#if LIBAVFILTER_VERSION_MAJOR > 8 || \ + (LIBAVFILTER_VERSION_MAJOR == 8 && LIBAVFILTER_VERSION_MINOR >= 44) + return codecpar->ch_layout.nb_channels; +#else + return codecpar->channels; +#endif +} + void setDefaultChannelLayout( UniqueAVCodecContext& avCodecContext, int numChannels) { @@ -254,7 +265,7 @@ void validateNumChannels(const AVCodec& avCodec, int numChannels) { avCodec.channel_layouts[i]); } #endif - TORCH_CHECK( + STD_TORCH_CHECK( false, "Desired number of channels (", numChannels, @@ -308,7 +319,7 @@ void setChannelLayout( AVChannelLayout outLayout = 
getOutputChannelLayout(outNumChannels, srcAVFrame); auto status = av_channel_layout_copy(&dstAVFrame->ch_layout, &outLayout); - TORCH_CHECK( + STD_TORCH_CHECK( status == AVSUCCESS, "Couldn't copy channel layout to avFrame: ", getFFMPEGErrorStringFromErrorCode(status)); @@ -325,7 +336,7 @@ UniqueAVFrame allocateAVFrame( int numChannels, AVSampleFormat sampleFormat) { auto avFrame = UniqueAVFrame(av_frame_alloc()); - TORCH_CHECK(avFrame != nullptr, "Couldn't allocate AVFrame."); + STD_TORCH_CHECK(avFrame != nullptr, "Couldn't allocate AVFrame."); avFrame->nb_samples = numSamples; avFrame->sample_rate = sampleRate; @@ -333,13 +344,13 @@ UniqueAVFrame allocateAVFrame( avFrame->format = sampleFormat; auto status = av_frame_get_buffer(avFrame.get(), 0); - TORCH_CHECK( + STD_TORCH_CHECK( status == AVSUCCESS, "Couldn't allocate avFrame's buffers: ", getFFMPEGErrorStringFromErrorCode(status)); status = av_frame_make_writable(avFrame.get()); - TORCH_CHECK( + STD_TORCH_CHECK( status == AVSUCCESS, "Couldn't make AVFrame writable: ", getFFMPEGErrorStringFromErrorCode(status)); @@ -369,7 +380,7 @@ SwrContext* createSwrContext( 0, nullptr); - TORCH_CHECK( + STD_TORCH_CHECK( status == AVSUCCESS, "Couldn't create SwrContext: ", getFFMPEGErrorStringFromErrorCode(status)); @@ -387,9 +398,9 @@ SwrContext* createSwrContext( nullptr); #endif - TORCH_CHECK(swrContext != nullptr, "Couldn't create swrContext"); + STD_TORCH_CHECK(swrContext != nullptr, "Couldn't create swrContext"); status = swr_init(swrContext); - TORCH_CHECK( + STD_TORCH_CHECK( status == AVSUCCESS, "Couldn't initialize SwrContext: ", getFFMPEGErrorStringFromErrorCode(status), @@ -399,68 +410,65 @@ SwrContext* createSwrContext( return swrContext; } -AVFilterContext* createBuffersinkFilter( +AVFilterContext* createAVFilterContextWithOptions( AVFilterGraph* filterGraph, - enum AVPixelFormat outputFormat) { - const AVFilter* buffersink = avfilter_get_by_name("buffersink"); - TORCH_CHECK(buffersink != nullptr, "Failed to get 
buffersink filter."); - - AVFilterContext* sinkContext = nullptr; - int status; + const AVFilter* buffer, + const enum AVPixelFormat outputFormat) { + AVFilterContext* avFilterContext = nullptr; const char* filterName = "out"; - enum AVPixelFormat pix_fmts[] = {outputFormat, AV_PIX_FMT_NONE}; + enum AVPixelFormat pixFmts[] = {outputFormat, AV_PIX_FMT_NONE}; // av_opt_set_int_list was replaced by av_opt_set_array() in FFmpeg 8. #if LIBAVUTIL_VERSION_MAJOR >= 60 // FFmpeg >= 8 // Output options like pixel_formats must be set before filter init - sinkContext = - avfilter_graph_alloc_filter(filterGraph, buffersink, filterName); - TORCH_CHECK( - sinkContext != nullptr, "Failed to allocate buffersink filter context."); + avFilterContext = + avfilter_graph_alloc_filter(filterGraph, buffer, filterName); + STD_TORCH_CHECK( + avFilterContext != nullptr, "Failed to allocate buffer filter context."); // When setting pix_fmts, only the first element is used, so nb_elems = 1 // AV_PIX_FMT_NONE acts as a terminator for the array in av_opt_set_int_list - status = av_opt_set_array( - sinkContext, + int status = av_opt_set_array( + avFilterContext, "pixel_formats", AV_OPT_SEARCH_CHILDREN, 0, // start_elem 1, // nb_elems AV_OPT_TYPE_PIXEL_FMT, - pix_fmts); - TORCH_CHECK( + pixFmts); + STD_TORCH_CHECK( status >= 0, - "Failed to set pixel format for buffersink filter: ", + "Failed to set pixel format for buffer filter: ", getFFMPEGErrorStringFromErrorCode(status)); - status = avfilter_init_str(sinkContext, nullptr); - TORCH_CHECK( + status = avfilter_init_str(avFilterContext, nullptr); + STD_TORCH_CHECK( status >= 0, - "Failed to initialize buffersink filter: ", + "Failed to initialize buffer filter: ", getFFMPEGErrorStringFromErrorCode(status)); #else // FFmpeg <= 7 // For older FFmpeg versions, create filter and then set options - status = avfilter_graph_create_filter( - &sinkContext, buffersink, filterName, nullptr, nullptr, filterGraph); - TORCH_CHECK( + int status = 
avfilter_graph_create_filter( + &avFilterContext, buffer, filterName, nullptr, nullptr, filterGraph); + STD_TORCH_CHECK( status >= 0, - "Failed to create buffersink filter: ", + "Failed to create buffer filter: ", getFFMPEGErrorStringFromErrorCode(status)); status = av_opt_set_int_list( - sinkContext, + avFilterContext, "pix_fmts", - pix_fmts, + pixFmts, AV_PIX_FMT_NONE, AV_OPT_SEARCH_CHILDREN); - TORCH_CHECK( + STD_TORCH_CHECK( status >= 0, - "Failed to set pixel formats for buffersink filter: ", + "Failed to set pixel formats for buffer filter: ", getFFMPEGErrorStringFromErrorCode(status)); #endif - return sinkContext; + return avFilterContext; } UniqueAVFrame convertAudioAVFrameSamples( @@ -470,7 +478,7 @@ UniqueAVFrame convertAudioAVFrameSamples( int outSampleRate, int outNumChannels) { UniqueAVFrame convertedAVFrame(av_frame_alloc()); - TORCH_CHECK( + STD_TORCH_CHECK( convertedAVFrame, "Could not allocate frame for sample format conversion."); @@ -500,22 +508,26 @@ UniqueAVFrame convertAudioAVFrameSamples( setChannelLayout(convertedAVFrame, srcAVFrame, outNumChannels); auto status = av_frame_get_buffer(convertedAVFrame.get(), 0); - TORCH_CHECK( + STD_TORCH_CHECK( status == AVSUCCESS, "Could not allocate frame buffers for sample format conversion: ", getFFMPEGErrorStringFromErrorCode(status)); + // Below we use AVFrame->extended_data instead of AVFrame->data to support + // decoding audio with >8 audio channels. extended_data contains pointers + // for all channels, while data only contains AV_NUM_DATA_POINTERS (8). 
+ // https://ffmpeg.org/doxygen/trunk/structAVFrame.html#afca04d808393822625e09b5ba91c6756 auto numConvertedSamples = swr_convert( swrContext.get(), - convertedAVFrame->data, + convertedAVFrame->extended_data, convertedAVFrame->nb_samples, static_cast( - const_cast(srcAVFrame->data)), + const_cast(srcAVFrame->extended_data)), srcAVFrame->nb_samples); // numConvertedSamples can be 0 if we're downsampling by a great factor and // the first frame doesn't contain a lot of samples. It should be handled // properly by the caller. - TORCH_CHECK( + STD_TORCH_CHECK( numConvertedSamples >= 0, "Error in swr_convert: ", getFFMPEGErrorStringFromErrorCode(numConvertedSamples)); @@ -550,7 +562,7 @@ void setFFmpegLogLevel() { } else if (logLevelEnv == "TRACE") { logLevel = AV_LOG_TRACE; } else { - TORCH_CHECK( + STD_TORCH_CHECK( false, "Invalid TORCHCODEC_FFMPEG_LOG_LEVEL: ", logLevelEnv, @@ -605,45 +617,111 @@ int64_t computeSafeDuration( } } -SwsFrameContext::SwsFrameContext( +std::optional getRotationFromStream(const AVStream* avStream) { + // av_stream_get_side_data() was deprecated in FFmpeg 6.0, but its replacement + // (av_packet_side_data_get() + codecpar->coded_side_data) is only available + // from FFmpeg 6.1. We need some #pragma magic to silence the deprecation + // warning which our compile chain would otherwise treat as an error. 
+ if (avStream == nullptr) { + return std::nullopt; + } + + const int32_t* displayMatrix = nullptr; + +// FFmpeg >= 6.1: Use codecpar->coded_side_data +#if LIBAVCODEC_VERSION_INT >= AV_VERSION_INT(60, 31, 100) + const AVPacketSideData* sideData = av_packet_side_data_get( + avStream->codecpar->coded_side_data, + avStream->codecpar->nb_coded_side_data, + AV_PKT_DATA_DISPLAYMATRIX); + if (sideData != nullptr) { + displayMatrix = reinterpret_cast(sideData->data); + } +#elif LIBAVFORMAT_VERSION_MAJOR >= 60 // FFmpeg 6.0 + // FFmpeg 6.0: Use av_stream_get_side_data (deprecated but still available) + // Suppress deprecation warning for this specific call +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wdeprecated-declarations" + size_t sideDataSize = 0; + const uint8_t* sideData = av_stream_get_side_data( + avStream, AV_PKT_DATA_DISPLAYMATRIX, &sideDataSize); +#pragma GCC diagnostic pop + if (sideData != nullptr) { + displayMatrix = reinterpret_cast(sideData); + } +#else + // FFmpeg < 6: Use av_stream_get_side_data. + // The size parameter type changed from int* (FFmpeg 4) to size_t* (FFmpeg 5) +#if LIBAVFORMAT_VERSION_MAJOR >= 59 // FFmpeg 5 + size_t sideDataSize = 0; +#else // FFmpeg 4 + int sideDataSize = 0; +#endif + const uint8_t* sideData = av_stream_get_side_data( + avStream, AV_PKT_DATA_DISPLAYMATRIX, &sideDataSize); + if (sideData != nullptr) { + displayMatrix = reinterpret_cast(sideData); + } +#endif + + if (displayMatrix == nullptr) { + return std::nullopt; + } + + // av_display_rotation_get returns the rotation angle in degrees needed to + // rotate the video counter-clockwise to make it upright. + // Returns NaN if the matrix is invalid. 
+ double rotation = av_display_rotation_get(displayMatrix); + + // Check for invalid matrix + if (std::isnan(rotation)) { + return std::nullopt; + } + + return rotation; +} + +SwsConfig::SwsConfig( int inputWidth, int inputHeight, AVPixelFormat inputFormat, + AVColorSpace inputColorspace, int outputWidth, int outputHeight) : inputWidth(inputWidth), inputHeight(inputHeight), inputFormat(inputFormat), + inputColorspace(inputColorspace), outputWidth(outputWidth), outputHeight(outputHeight) {} -bool SwsFrameContext::operator==(const SwsFrameContext& other) const { +bool SwsConfig::operator==(const SwsConfig& other) const { return inputWidth == other.inputWidth && inputHeight == other.inputHeight && - inputFormat == other.inputFormat && outputWidth == other.outputWidth && - outputHeight == other.outputHeight; + inputFormat == other.inputFormat && + inputColorspace == other.inputColorspace && + outputWidth == other.outputWidth && outputHeight == other.outputHeight; } -bool SwsFrameContext::operator!=(const SwsFrameContext& other) const { +bool SwsConfig::operator!=(const SwsConfig& other) const { return !(*this == other); } UniqueSwsContext createSwsContext( - const SwsFrameContext& swsFrameContext, - AVColorSpace colorspace, + const SwsConfig& swsConfig, AVPixelFormat outputFormat, int swsFlags) { SwsContext* swsContext = sws_getContext( - swsFrameContext.inputWidth, - swsFrameContext.inputHeight, - swsFrameContext.inputFormat, - swsFrameContext.outputWidth, - swsFrameContext.outputHeight, + swsConfig.inputWidth, + swsConfig.inputHeight, + swsConfig.inputFormat, + swsConfig.outputWidth, + swsConfig.outputHeight, outputFormat, swsFlags, nullptr, nullptr, nullptr); - TORCH_CHECK(swsContext, "sws_getContext() returned nullptr"); + STD_TORCH_CHECK(swsContext, "sws_getContext() returned nullptr"); int* invTable = nullptr; int* table = nullptr; @@ -657,9 +735,9 @@ UniqueSwsContext createSwsContext( &brightness, &contrast, &saturation); - TORCH_CHECK(ret != -1, 
"sws_getColorspaceDetails returned -1"); + STD_TORCH_CHECK(ret != -1, "sws_getColorspaceDetails returned -1"); - const int* colorspaceTable = sws_getCoefficients(colorspace); + const int* colorspaceTable = sws_getCoefficients(swsConfig.inputColorspace); ret = sws_setColorspaceDetails( swsContext, colorspaceTable, @@ -669,7 +747,7 @@ UniqueSwsContext createSwsContext( brightness, contrast, saturation); - TORCH_CHECK(ret != -1, "sws_setColorspaceDetails returned -1"); + STD_TORCH_CHECK(ret != -1, "sws_setColorspaceDetails returned -1"); return UniqueSwsContext(swsContext); } diff --git a/src/torchcodec/_core/FFMPEGCommon.h b/src/torchcodec/_core/FFMPEGCommon.h index 2d58abfb2..77738c3c6 100644 --- a/src/torchcodec/_core/FFMPEGCommon.h +++ b/src/torchcodec/_core/FFMPEGCommon.h @@ -6,8 +6,8 @@ #pragma once -#include #include +#include #include #include @@ -104,6 +104,39 @@ using UniqueAVBufferSrcParameters = std::unique_ptr< AVBufferSrcParameters, Deleterv>; +// Wrapper class for AVDictionary, similar to unique_ptr, to support FFmpeg's +// functions that require a double-pointer to AVDictionary, such as av_dict_set. +// https://ffmpeg.org/doxygen/trunk/group__lavu__dict.html#ga8d9c2de72b310cef8e6a28c9cd3acbbe +class UniqueAVDictionary { + private: + AVDictionary* dict_ = nullptr; + + public: + UniqueAVDictionary() = default; + + ~UniqueAVDictionary() { + if (dict_) { + av_dict_free(&dict_); + } + } + + // Explicitly delete copy operator similar to unique_ptr + UniqueAVDictionary(const UniqueAVDictionary&) = delete; + UniqueAVDictionary& operator=(const UniqueAVDictionary&) = delete; + // Explicitly delete move operator, as it is not needed at this time. + UniqueAVDictionary(UniqueAVDictionary&&) = delete; + UniqueAVDictionary& operator=(UniqueAVDictionary&&) = delete; + + // FFmpeg's AVDictionary functions require a AVDictionary** argument. 
+ // However, unique_ptr's get() function returns a **temporary** pointer to the + // object, so we cannot get a pointer to the internal AVDictionary pointer. + // As a result, we implement getAddress() to return a pointer to the internal + // AVDictionary pointer. + AVDictionary** getAddress() { + return &dict_; + } +}; + // These 2 classes share the same underlying AVPacket object. They are meant to // be used in tandem, like so: // @@ -181,6 +214,7 @@ const AVPixelFormat* getSupportedPixelFormats(const AVCodec& avCodec); int getNumChannels(const UniqueAVFrame& avFrame); int getNumChannels(const SharedAVCodecContext& avCodecContext); +int getNumChannels(const AVCodecParameters* codecpar); void setDefaultChannelLayout( UniqueAVCodecContext& avCodecContext, @@ -247,34 +281,41 @@ int64_t computeSafeDuration( const AVRational& frameRate, const AVRational& timeBase); -AVFilterContext* createBuffersinkFilter( +// Extracts the rotation angle in degrees from the stream's display matrix +// side data. The display matrix is used to specify how the video should be +// rotated for correct display. 
+std::optional getRotationFromStream(const AVStream* avStream); + +AVFilterContext* createAVFilterContextWithOptions( AVFilterGraph* filterGraph, - enum AVPixelFormat outputFormat); + const AVFilter* buffer, + const enum AVPixelFormat outputFormat); -struct SwsFrameContext { +struct SwsConfig { int inputWidth = 0; int inputHeight = 0; AVPixelFormat inputFormat = AV_PIX_FMT_NONE; + AVColorSpace inputColorspace = AVCOL_SPC_UNSPECIFIED; int outputWidth = 0; int outputHeight = 0; - SwsFrameContext() = default; - SwsFrameContext( + SwsConfig() = default; + SwsConfig( int inputWidth, int inputHeight, AVPixelFormat inputFormat, + AVColorSpace inputColorspace, int outputWidth, int outputHeight); - bool operator==(const SwsFrameContext& other) const; - bool operator!=(const SwsFrameContext& other) const; + bool operator==(const SwsConfig& other) const; + bool operator!=(const SwsConfig& other) const; }; // Utility functions for swscale context management UniqueSwsContext createSwsContext( - const SwsFrameContext& swsFrameContext, - AVColorSpace colorspace, - AVPixelFormat outputFormat = AV_PIX_FMT_RGB24, - int swsFlags = SWS_BILINEAR); + const SwsConfig& swsConfig, + AVPixelFormat outputFormat, + int swsFlags); } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/FilterGraph.cpp b/src/torchcodec/_core/FilterGraph.cpp index 605b814a8..904886908 100644 --- a/src/torchcodec/_core/FilterGraph.cpp +++ b/src/torchcodec/_core/FilterGraph.cpp @@ -4,8 +4,9 @@ // This source code is licensed under the BSD-style license found in the // LICENSE file in the root directory of this source tree. 
-#include "src/torchcodec/_core/FilterGraph.h" -#include "src/torchcodec/_core/FFMPEGCommon.h" +#include "FilterGraph.h" +#include "FFMPEGCommon.h" +#include "StableABICompat.h" extern "C" { #include @@ -14,7 +15,7 @@ extern "C" { namespace facebook::torchcodec { -FiltersContext::FiltersContext( +FiltersConfig::FiltersConfig( int inputWidth, int inputHeight, AVPixelFormat inputFormat, @@ -40,7 +41,7 @@ bool operator==(const AVRational& lhs, const AVRational& rhs) { return lhs.num == rhs.num && lhs.den == rhs.den; } -bool FiltersContext::operator==(const FiltersContext& other) const { +bool FiltersConfig::operator==(const FiltersConfig& other) const { return inputWidth == other.inputWidth && inputHeight == other.inputHeight && inputFormat == other.inputFormat && outputWidth == other.outputWidth && outputHeight == other.outputHeight && @@ -49,99 +50,109 @@ bool FiltersContext::operator==(const FiltersContext& other) const { hwFramesCtx.get() == other.hwFramesCtx.get(); } -bool FiltersContext::operator!=(const FiltersContext& other) const { +bool FiltersConfig::operator!=(const FiltersConfig& other) const { return !(*this == other); } FilterGraph::FilterGraph( - const FiltersContext& filtersContext, + const FiltersConfig& filtersConfig, const VideoStreamOptions& videoStreamOptions) { filterGraph_.reset(avfilter_graph_alloc()); - TORCH_CHECK(filterGraph_.get() != nullptr); + STD_TORCH_CHECK( + filterGraph_.get() != nullptr, "Failed to allocate filter graph"); if (videoStreamOptions.ffmpegThreadCount.has_value()) { filterGraph_->nb_threads = videoStreamOptions.ffmpegThreadCount.value(); } - const AVFilter* buffersrc = avfilter_get_by_name("buffer"); - + // Configure the source context. 
+ const AVFilter* bufferSrc = avfilter_get_by_name("buffer"); UniqueAVBufferSrcParameters srcParams(av_buffersrc_parameters_alloc()); - TORCH_CHECK(srcParams, "Failed to allocate buffersrc params"); - - srcParams->format = filtersContext.inputFormat; - srcParams->width = filtersContext.inputWidth; - srcParams->height = filtersContext.inputHeight; - srcParams->sample_aspect_ratio = filtersContext.inputAspectRatio; - srcParams->time_base = filtersContext.timeBase; - if (filtersContext.hwFramesCtx) { - srcParams->hw_frames_ctx = av_buffer_ref(filtersContext.hwFramesCtx.get()); + STD_TORCH_CHECK(srcParams, "Failed to allocate buffersrc params"); + + srcParams->format = filtersConfig.inputFormat; + srcParams->width = filtersConfig.inputWidth; + srcParams->height = filtersConfig.inputHeight; + srcParams->sample_aspect_ratio = filtersConfig.inputAspectRatio; + srcParams->time_base = filtersConfig.timeBase; + if (filtersConfig.hwFramesCtx) { + srcParams->hw_frames_ctx = av_buffer_ref(filtersConfig.hwFramesCtx.get()); } sourceContext_ = - avfilter_graph_alloc_filter(filterGraph_.get(), buffersrc, "in"); - TORCH_CHECK(sourceContext_, "Failed to allocate filter graph"); + avfilter_graph_alloc_filter(filterGraph_.get(), bufferSrc, "in"); + STD_TORCH_CHECK(sourceContext_, "Failed to allocate filter graph"); int status = av_buffersrc_parameters_set(sourceContext_, srcParams.get()); - TORCH_CHECK( + STD_TORCH_CHECK( status >= 0, "Failed to create filter graph: ", getFFMPEGErrorStringFromErrorCode(status)); status = avfilter_init_str(sourceContext_, nullptr); - TORCH_CHECK( + STD_TORCH_CHECK( status >= 0, "Failed to create filter graph : ", getFFMPEGErrorStringFromErrorCode(status)); - sinkContext_ = - createBuffersinkFilter(filterGraph_.get(), filtersContext.outputFormat); - TORCH_CHECK( + // Configure the sink context. 
+ const AVFilter* bufferSink = avfilter_get_by_name("buffersink"); + STD_TORCH_CHECK(bufferSink != nullptr, "Failed to get buffersink filter."); + + sinkContext_ = createAVFilterContextWithOptions( + filterGraph_.get(), bufferSink, filtersConfig.outputFormat); + STD_TORCH_CHECK( sinkContext_ != nullptr, "Failed to create and configure buffersink"); + // Create the filtergraph nodes based on the source and sink contexts. UniqueAVFilterInOut outputs(avfilter_inout_alloc()); - UniqueAVFilterInOut inputs(avfilter_inout_alloc()); - outputs->name = av_strdup("in"); outputs->filter_ctx = sourceContext_; outputs->pad_idx = 0; outputs->next = nullptr; + + UniqueAVFilterInOut inputs(avfilter_inout_alloc()); inputs->name = av_strdup("out"); inputs->filter_ctx = sinkContext_; inputs->pad_idx = 0; inputs->next = nullptr; + // Create the filtergraph specified by the filtergraph string in the context + // of the inputs and outputs. Note the dance we have to do with release and + // resetting the output and input nodes because FFmpeg modifies them in place. AVFilterInOut* outputsTmp = outputs.release(); AVFilterInOut* inputsTmp = inputs.release(); status = avfilter_graph_parse_ptr( filterGraph_.get(), - filtersContext.filtergraphStr.c_str(), + filtersConfig.filtergraphStr.c_str(), &inputsTmp, &outputsTmp, nullptr); outputs.reset(outputsTmp); inputs.reset(inputsTmp); - TORCH_CHECK( + STD_TORCH_CHECK( status >= 0, "Failed to parse filter description: ", getFFMPEGErrorStringFromErrorCode(status), - ", provided filters: " + filtersContext.filtergraphStr); + ", provided filters: " + filtersConfig.filtergraphStr); + // Check filtergraph validity and configure links and formats. 
status = avfilter_graph_config(filterGraph_.get(), nullptr); - TORCH_CHECK( + STD_TORCH_CHECK( status >= 0, "Failed to configure filter graph: ", getFFMPEGErrorStringFromErrorCode(status), - ", provided filters: " + filtersContext.filtergraphStr); + ", provided filters: " + filtersConfig.filtergraphStr); } UniqueAVFrame FilterGraph::convert(const UniqueAVFrame& avFrame) { int status = av_buffersrc_write_frame(sourceContext_, avFrame.get()); - TORCH_CHECK( + STD_TORCH_CHECK( status >= AVSUCCESS, "Failed to add frame to buffer source context"); UniqueAVFrame filteredAVFrame(av_frame_alloc()); status = av_buffersink_get_frame(sinkContext_, filteredAVFrame.get()); - TORCH_CHECK( + STD_TORCH_CHECK( status >= AVSUCCESS, "Failed to get frame from buffer sink context"); return filteredAVFrame; diff --git a/src/torchcodec/_core/FilterGraph.h b/src/torchcodec/_core/FilterGraph.h index 8cba571bd..4e5257d79 100644 --- a/src/torchcodec/_core/FilterGraph.h +++ b/src/torchcodec/_core/FilterGraph.h @@ -6,12 +6,12 @@ #pragma once -#include "src/torchcodec/_core/FFMPEGCommon.h" -#include "src/torchcodec/_core/StreamOptions.h" +#include "FFMPEGCommon.h" +#include "StreamOptions.h" namespace facebook::torchcodec { -struct FiltersContext { +struct FiltersConfig { int inputWidth = 0; int inputHeight = 0; AVPixelFormat inputFormat = AV_PIX_FMT_NONE; @@ -23,10 +23,10 @@ struct FiltersContext { AVRational timeBase = {0, 0}; UniqueAVBufferRef hwFramesCtx; - FiltersContext() = default; - FiltersContext(FiltersContext&&) = default; - FiltersContext& operator=(FiltersContext&&) = default; - FiltersContext( + FiltersConfig() = default; + FiltersConfig(FiltersConfig&&) = default; + FiltersConfig& operator=(FiltersConfig&&) = default; + FiltersConfig( int inputWidth, int inputHeight, AVPixelFormat inputFormat, @@ -38,14 +38,14 @@ struct FiltersContext { AVRational timeBase, AVBufferRef* hwFramesCtx = nullptr); - bool operator==(const FiltersContext&) const; - bool operator!=(const 
FiltersContext&) const; + bool operator==(const FiltersConfig&) const; + bool operator!=(const FiltersConfig&) const; }; class FilterGraph { public: FilterGraph( - const FiltersContext& filtersContext, + const FiltersConfig& filtersConfig, const VideoStreamOptions& videoStreamOptions); UniqueAVFrame convert(const UniqueAVFrame& avFrame); diff --git a/src/torchcodec/_core/Frame.cpp b/src/torchcodec/_core/Frame.cpp index 62fb46c65..233305c27 100644 --- a/src/torchcodec/_core/Frame.cpp +++ b/src/torchcodec/_core/Frame.cpp @@ -4,43 +4,48 @@ // This source code is licensed under the BSD-style license found in the // LICENSE file in the root directory of this source tree. -#include "src/torchcodec/_core/Frame.h" +#include "Frame.h" +#include "StableABICompat.h" namespace facebook::torchcodec { FrameDims::FrameDims(int height, int width) : height(height), width(width) { - TORCH_CHECK(height > 0, "FrameDims.height must be > 0, got: ", height); - TORCH_CHECK(width > 0, "FrameDims.width must be > 0, got: ", width); + STD_TORCH_CHECK(height > 0, "FrameDims.height must be > 0, got: ", height); + STD_TORCH_CHECK(width > 0, "FrameDims.width must be > 0, got: ", width); } FrameBatchOutput::FrameBatchOutput( int64_t numFrames, const FrameDims& outputDims, - const torch::Device& device) - : ptsSeconds(torch::empty({numFrames}, {torch::kFloat64})), - durationSeconds(torch::empty({numFrames}, {torch::kFloat64})) { + const StableDevice& device) + : ptsSeconds(torch::stable::empty({numFrames}, kStableFloat64)), + durationSeconds(torch::stable::empty({numFrames}, kStableFloat64)) { data = allocateEmptyHWCTensor(outputDims, device, numFrames); } -torch::Tensor allocateEmptyHWCTensor( +torch::stable::Tensor allocateEmptyHWCTensor( const FrameDims& frameDims, - const torch::Device& device, + const StableDevice& device, std::optional numFrames) { - auto tensorOptions = torch::TensorOptions() - .dtype(torch::kUInt8) - .layout(torch::kStrided) - .device(device); - TORCH_CHECK( + 
STD_TORCH_CHECK( frameDims.height > 0, "height must be > 0, got: ", frameDims.height); - TORCH_CHECK(frameDims.width > 0, "width must be > 0, got: ", frameDims.width); + STD_TORCH_CHECK( + frameDims.width > 0, "width must be > 0, got: ", frameDims.width); if (numFrames.has_value()) { auto numFramesValue = numFrames.value(); - TORCH_CHECK( + STD_TORCH_CHECK( numFramesValue >= 0, "numFrames must be >= 0, got: ", numFramesValue); - return torch::empty( - {numFramesValue, frameDims.height, frameDims.width, 3}, tensorOptions); + return torch::stable::empty( + {numFramesValue, frameDims.height, frameDims.width, 3}, + kStableUInt8, + std::nullopt, + device); } else { - return torch::empty({frameDims.height, frameDims.width, 3}, tensorOptions); + return torch::stable::empty( + {frameDims.height, frameDims.width, 3}, + kStableUInt8, + std::nullopt, + device); } } diff --git a/src/torchcodec/_core/Frame.h b/src/torchcodec/_core/Frame.h index 67e4d2b79..5e8baa07c 100644 --- a/src/torchcodec/_core/Frame.h +++ b/src/torchcodec/_core/Frame.h @@ -6,10 +6,10 @@ #pragma once -#include -#include "src/torchcodec/_core/FFMPEGCommon.h" -#include "src/torchcodec/_core/Metadata.h" -#include "src/torchcodec/_core/StreamOptions.h" +#include "FFMPEGCommon.h" +#include "Metadata.h" +#include "StableABICompat.h" +#include "StreamOptions.h" namespace facebook::torchcodec { @@ -33,24 +33,24 @@ struct FrameOutput { // data shape is: // - 3D (C, H, W) or (H, W, C) for videos // - 2D (numChannels, numSamples) for audio - torch::Tensor data; + torch::stable::Tensor data; double ptsSeconds; double durationSeconds; }; struct FrameBatchOutput { - torch::Tensor data; // 4D: of shape NCHW or NHWC. - torch::Tensor ptsSeconds; // 1D of shape (N,) - torch::Tensor durationSeconds; // 1D of shape (N,) + torch::stable::Tensor data; // 4D: of shape NCHW or NHWC. 
+ torch::stable::Tensor ptsSeconds; // 1D of shape (N,) + torch::stable::Tensor durationSeconds; // 1D of shape (N,) FrameBatchOutput( int64_t numFrames, const FrameDims& outputDims, - const torch::Device& device); + const StableDevice& device); }; struct AudioFramesOutput { - torch::Tensor data; // shape is (numChannels, numSamples) + torch::stable::Tensor data; // shape is (numChannels, numSamples) double ptsSeconds; }; @@ -64,9 +64,9 @@ struct AudioFramesOutput { // assume HWC tensors, since this is what FFmpeg natively handles. It's up to // the high-level decoding entry-points to permute that back to CHW, by calling // maybePermuteHWC2CHW(). -torch::Tensor allocateEmptyHWCTensor( +torch::stable::Tensor allocateEmptyHWCTensor( const FrameDims& frameDims, - const torch::Device& device, + const StableDevice& device, std::optional numFrames = std::nullopt); } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/Metadata.cpp b/src/torchcodec/_core/Metadata.cpp new file mode 100644 index 000000000..079c8ef41 --- /dev/null +++ b/src/torchcodec/_core/Metadata.cpp @@ -0,0 +1,164 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
+ +#include "Metadata.h" +#include "StableABICompat.h" + +extern "C" { +#include +} + +namespace facebook::torchcodec { + +std::optional StreamMetadata::getDurationSeconds( + SeekMode seekMode) const { + switch (seekMode) { + case SeekMode::custom_frame_mappings: + case SeekMode::exact: + STD_TORCH_CHECK( + endStreamPtsSecondsFromContent.has_value() && + beginStreamPtsSecondsFromContent.has_value(), + "Missing beginStreamPtsSecondsFromContent or endStreamPtsSecondsFromContent"); + return endStreamPtsSecondsFromContent.value() - + beginStreamPtsSecondsFromContent.value(); + case SeekMode::approximate: + if (durationSecondsFromHeader.has_value()) { + return durationSecondsFromHeader.value(); + } + if (numFramesFromHeader.has_value() && averageFpsFromHeader.has_value() && + averageFpsFromHeader.value() != 0.0) { + return static_cast(numFramesFromHeader.value()) / + averageFpsFromHeader.value(); + } + if (durationSecondsFromContainer.has_value()) { + return durationSecondsFromContainer.value(); + } + return std::nullopt; + default: + STD_TORCH_CHECK(false, "Unknown SeekMode"); + } +} + +double StreamMetadata::getBeginStreamSeconds(SeekMode seekMode) const { + switch (seekMode) { + case SeekMode::custom_frame_mappings: + case SeekMode::exact: + STD_TORCH_CHECK( + beginStreamPtsSecondsFromContent.has_value(), + "Missing beginStreamPtsSecondsFromContent"); + return beginStreamPtsSecondsFromContent.value(); + case SeekMode::approximate: + if (beginStreamSecondsFromHeader.has_value()) { + return beginStreamSecondsFromHeader.value(); + } + return 0.0; + default: + STD_TORCH_CHECK(false, "Unknown SeekMode"); + } +} + +std::optional StreamMetadata::getEndStreamSeconds( + SeekMode seekMode) const { + switch (seekMode) { + case SeekMode::custom_frame_mappings: + case SeekMode::exact: + STD_TORCH_CHECK( + endStreamPtsSecondsFromContent.has_value(), + "Missing endStreamPtsSecondsFromContent"); + return endStreamPtsSecondsFromContent.value(); + case SeekMode::approximate: { + auto 
dur = getDurationSeconds(seekMode); + if (dur.has_value()) { + return getBeginStreamSeconds(seekMode) + dur.value(); + } + return std::nullopt; + } + default: + STD_TORCH_CHECK(false, "Unknown SeekMode"); + } +} + +std::optional StreamMetadata::getNumFrames(SeekMode seekMode) const { + switch (seekMode) { + case SeekMode::custom_frame_mappings: + case SeekMode::exact: + STD_TORCH_CHECK( + numFramesFromContent.has_value(), "Missing numFramesFromContent"); + return numFramesFromContent.value(); + case SeekMode::approximate: { + auto durationSeconds = getDurationSeconds(seekMode); + if (numFramesFromHeader.has_value()) { + return numFramesFromHeader.value(); + } + if (averageFpsFromHeader.has_value() && durationSeconds.has_value()) { + return static_cast( + averageFpsFromHeader.value() * durationSeconds.value()); + } + return std::nullopt; + } + default: + STD_TORCH_CHECK(false, "Unknown SeekMode"); + } +} + +std::optional StreamMetadata::getAverageFps(SeekMode seekMode) const { + switch (seekMode) { + case SeekMode::custom_frame_mappings: + case SeekMode::exact: { + auto numFrames = getNumFrames(seekMode); + if (numFrames.has_value() && + beginStreamPtsSecondsFromContent.has_value() && + endStreamPtsSecondsFromContent.has_value()) { + double duration = endStreamPtsSecondsFromContent.value() - + beginStreamPtsSecondsFromContent.value(); + if (duration != 0.0) { + return static_cast(numFrames.value()) / duration; + } + } + return averageFpsFromHeader; + } + case SeekMode::approximate: + return averageFpsFromHeader; + default: + STD_TORCH_CHECK(false, "Unknown SeekMode"); + } +} + +std::optional StreamMetadata::getColorPrimariesName() const { + if (!colorPrimaries.has_value()) { + return std::nullopt; + } + const char* name = av_color_primaries_name(*colorPrimaries); + if (name == nullptr) { + return std::nullopt; + } + return std::string(name); +} + +std::optional StreamMetadata::getColorSpaceName() const { + if (!colorSpace.has_value()) { + return std::nullopt; + } + 
const char* name = av_color_space_name(*colorSpace); + if (name == nullptr) { + return std::nullopt; + } + return std::string(name); +} + +std::optional StreamMetadata::getColorTransferCharacteristicName() + const { + if (!colorTransferCharacteristic.has_value()) { + return std::nullopt; + } + const char* name = av_color_transfer_name(*colorTransferCharacteristic); + if (name == nullptr) { + return std::nullopt; + } + return std::string(name); +} + +} // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/Metadata.h b/src/torchcodec/_core/Metadata.h index ace6cf84c..41d6046b9 100644 --- a/src/torchcodec/_core/Metadata.h +++ b/src/torchcodec/_core/Metadata.h @@ -13,17 +13,22 @@ extern "C" { #include #include +#include #include } namespace facebook::torchcodec { +enum class SeekMode { exact, approximate, custom_frame_mappings }; + struct StreamMetadata { // Common (video and audio) fields derived from the AVStream. - int streamIndex; + int streamIndex = -1; + // See this link for what various values are available: // https://ffmpeg.org/doxygen/trunk/group__lavu__misc.html#ga9a84bba4713dfced21a1a56163be1f48 - AVMediaType mediaType; + AVMediaType mediaType = AVMEDIA_TYPE_UNKNOWN; + std::optional codecId; std::optional codecName; std::optional durationSecondsFromHeader; @@ -33,38 +38,68 @@ struct StreamMetadata { std::optional averageFpsFromHeader; std::optional bitRate; + // Used as fallback in approximate mode when stream duration is unavailable. + std::optional durationSecondsFromContainer; + // More accurate duration, obtained by scanning the file. // These presentation timestamps are in time base. std::optional beginStreamPtsFromContent; std::optional endStreamPtsFromContent; + // These presentation timestamps are in seconds. std::optional beginStreamPtsSecondsFromContent; std::optional endStreamPtsSecondsFromContent; + // This can be useful for index-based seeking. 
std::optional numFramesFromContent; - // Video-only fields derived from the AVCodecContext. - std::optional width; - std::optional height; + // Video-only fields + // Post-rotation dimensions + std::optional postRotationWidth; + std::optional postRotationHeight; std::optional sampleAspectRatio; + // Rotation angle in degrees from display matrix, in the range [-180, 180]. + std::optional rotation; + std::optional colorPrimaries; + std::optional colorSpace; + std::optional colorTransferCharacteristic; + // The pixel format of the encoded video, e.g. "yuv420p". + std::optional pixelFormat; // Audio-only fields std::optional sampleRate; std::optional numChannels; std::optional sampleFormat; + + // Computed methods with fallback logic + std::optional getDurationSeconds(SeekMode seekMode) const; + double getBeginStreamSeconds(SeekMode seekMode) const; + std::optional getEndStreamSeconds(SeekMode seekMode) const; + std::optional getNumFrames(SeekMode seekMode) const; + std::optional getAverageFps(SeekMode seekMode) const; + + // Color metadata name accessors. These return nullopt if the field is unset + // or if FFmpeg returns NULL for the name. + std::optional getColorPrimariesName() const; + std::optional getColorSpaceName() const; + std::optional getColorTransferCharacteristicName() const; }; struct ContainerMetadata { std::vector allStreamMetadata; int numAudioStreams = 0; int numVideoStreams = 0; + // Note that this is the container-level duration, which is usually the max // of all stream durations available in the container. std::optional durationSecondsFromHeader; + // Total BitRate level information at the container level in bit/s std::optional bitRate; + // If set, this is the index to the default audio stream. std::optional bestAudioStreamIndex; + // If set, this is the index to the default video stream. 
std::optional bestVideoStreamIndex; }; diff --git a/src/torchcodec/_core/NVCUVIDRuntimeLoader.cpp b/src/torchcodec/_core/NVCUVIDRuntimeLoader.cpp index 2bb501fc2..2ffc72faf 100644 --- a/src/torchcodec/_core/NVCUVIDRuntimeLoader.cpp +++ b/src/torchcodec/_core/NVCUVIDRuntimeLoader.cpp @@ -5,6 +5,8 @@ // LICENSE file in the root directory of this source tree. #ifdef FBCODE_CAFFE2 +#include "StableABICompat.h" + // No need to do anything on fbcode. NVCUVID is available there, we can take a // hard dependency on it. // The FBCODE_CAFFE2 macro is defined in the upstream fbcode build of torch, so @@ -17,14 +19,15 @@ bool loadNVCUVIDLibrary() { } // namespace facebook::torchcodec #else -#include "src/torchcodec/_core/NVCUVIDRuntimeLoader.h" +#include "NVCUVIDRuntimeLoader.h" +#include "StableABICompat.h" -#include "src/torchcodec/_core/nvcuvid_include/cuviddec.h" -#include "src/torchcodec/_core/nvcuvid_include/nvcuvid.h" +#include "nvcuvid_include/cuviddec.h" +#include "nvcuvid_include/nvcuvid.h" -#include #include #include +#include "StableABICompat.h" #if defined(WIN64) || defined(_WIN64) #include @@ -213,7 +216,7 @@ extern "C" { CUresult CUDAAPI cuvidCreateVideoParser( CUvideoparser* videoParser, CUVIDPARSERPARAMS* parserParams) { - TORCH_CHECK( + STD_TORCH_CHECK( facebook::torchcodec::dl_cuvidCreateVideoParser, "cuvidCreateVideoParser called but NVCUVID not loaded!"); return facebook::torchcodec::dl_cuvidCreateVideoParser( @@ -223,21 +226,21 @@ CUresult CUDAAPI cuvidCreateVideoParser( CUresult CUDAAPI cuvidParseVideoData( CUvideoparser videoParser, CUVIDSOURCEDATAPACKET* cuvidPacket) { - TORCH_CHECK( + STD_TORCH_CHECK( facebook::torchcodec::dl_cuvidParseVideoData, "cuvidParseVideoData called but NVCUVID not loaded!"); return facebook::torchcodec::dl_cuvidParseVideoData(videoParser, cuvidPacket); } CUresult CUDAAPI cuvidDestroyVideoParser(CUvideoparser videoParser) { - TORCH_CHECK( + STD_TORCH_CHECK( facebook::torchcodec::dl_cuvidDestroyVideoParser, 
"cuvidDestroyVideoParser called but NVCUVID not loaded!"); return facebook::torchcodec::dl_cuvidDestroyVideoParser(videoParser); } CUresult CUDAAPI cuvidGetDecoderCaps(CUVIDDECODECAPS* caps) { - TORCH_CHECK( + STD_TORCH_CHECK( facebook::torchcodec::dl_cuvidGetDecoderCaps, "cuvidGetDecoderCaps called but NVCUVID not loaded!"); return facebook::torchcodec::dl_cuvidGetDecoderCaps(caps); @@ -246,14 +249,14 @@ CUresult CUDAAPI cuvidGetDecoderCaps(CUVIDDECODECAPS* caps) { CUresult CUDAAPI cuvidCreateDecoder( CUvideodecoder* decoder, CUVIDDECODECREATEINFO* decoderParams) { - TORCH_CHECK( + STD_TORCH_CHECK( facebook::torchcodec::dl_cuvidCreateDecoder, "cuvidCreateDecoder called but NVCUVID not loaded!"); return facebook::torchcodec::dl_cuvidCreateDecoder(decoder, decoderParams); } CUresult CUDAAPI cuvidDestroyDecoder(CUvideodecoder decoder) { - TORCH_CHECK( + STD_TORCH_CHECK( facebook::torchcodec::dl_cuvidDestroyDecoder, "cuvidDestroyDecoder called but NVCUVID not loaded!"); return facebook::torchcodec::dl_cuvidDestroyDecoder(decoder); @@ -261,7 +264,7 @@ CUresult CUDAAPI cuvidDestroyDecoder(CUvideodecoder decoder) { CUresult CUDAAPI cuvidDecodePicture(CUvideodecoder decoder, CUVIDPICPARAMS* picParams) { - TORCH_CHECK( + STD_TORCH_CHECK( facebook::torchcodec::dl_cuvidDecodePicture, "cuvidDecodePicture called but NVCUVID not loaded!"); return facebook::torchcodec::dl_cuvidDecodePicture(decoder, picParams); @@ -278,7 +281,7 @@ CUresult CUDAAPI cuvidMapVideoFrame( unsigned int* framePtr, unsigned int* pitch, CUVIDPROCPARAMS* procParams) { - TORCH_CHECK( + STD_TORCH_CHECK( facebook::torchcodec::dl_cuvidMapVideoFrame, "cuvidMapVideoFrame called but NVCUVID not loaded!"); return facebook::torchcodec::dl_cuvidMapVideoFrame( @@ -287,7 +290,7 @@ CUresult CUDAAPI cuvidMapVideoFrame( CUresult CUDAAPI cuvidUnmapVideoFrame(CUvideodecoder decoder, unsigned int framePtr) { - TORCH_CHECK( + STD_TORCH_CHECK( facebook::torchcodec::dl_cuvidUnmapVideoFrame, "cuvidUnmapVideoFrame called but 
NVCUVID not loaded!"); return facebook::torchcodec::dl_cuvidUnmapVideoFrame(decoder, framePtr); @@ -300,7 +303,7 @@ CUresult CUDAAPI cuvidMapVideoFrame64( unsigned long long* framePtr, unsigned int* pitch, CUVIDPROCPARAMS* procParams) { - TORCH_CHECK( + STD_TORCH_CHECK( facebook::torchcodec::dl_cuvidMapVideoFrame64, "cuvidMapVideoFrame64 called but NVCUVID not loaded!"); return facebook::torchcodec::dl_cuvidMapVideoFrame64( @@ -309,7 +312,7 @@ CUresult CUDAAPI cuvidMapVideoFrame64( CUresult CUDAAPI cuvidUnmapVideoFrame64(CUvideodecoder decoder, unsigned long long framePtr) { - TORCH_CHECK( + STD_TORCH_CHECK( facebook::torchcodec::dl_cuvidUnmapVideoFrame64, "cuvidUnmapVideoFrame64 called but NVCUVID not loaded!"); return facebook::torchcodec::dl_cuvidUnmapVideoFrame64(decoder, framePtr); diff --git a/src/torchcodec/_core/NVDECCache.cpp b/src/torchcodec/_core/NVDECCache.cpp index 302433cd4..13b841072 100644 --- a/src/torchcodec/_core/NVDECCache.cpp +++ b/src/torchcodec/_core/NVDECCache.cpp @@ -4,12 +4,12 @@ // This source code is licensed under the BSD-style license found in the // LICENSE file in the root directory of this source tree. 
-#include #include -#include "src/torchcodec/_core/CUDACommon.h" -#include "src/torchcodec/_core/FFMPEGCommon.h" -#include "src/torchcodec/_core/NVDECCache.h" +#include "CUDACommon.h" +#include "FFMPEGCommon.h" +#include "NVDECCache.h" +#include "NVDECCacheConfig.h" #include // For cudaGetDevice @@ -20,18 +20,23 @@ extern "C" { namespace facebook::torchcodec { -NVDECCache& NVDECCache::getCache(const torch::Device& device) { +NVDECCache* NVDECCache::getCacheInstances() { static NVDECCache cacheInstances[MAX_CUDA_GPUS]; - return cacheInstances[getDeviceIndex(device)]; + return cacheInstances; +} + +NVDECCache& NVDECCache::getCache(const StableDevice& device) { + return getCacheInstances()[getDeviceIndex(device)]; } UniqueCUvideodecoder NVDECCache::getDecoder(CUVIDEOFORMAT* videoFormat) { CacheKey key(videoFormat); std::lock_guard lock(cacheLock_); + // Find an entry with matching key auto it = cache_.find(key); if (it != cache_.end()) { - auto decoder = std::move(it->second); + auto decoder = std::move(it->second.decoder); cache_.erase(it); return decoder; } @@ -39,22 +44,63 @@ UniqueCUvideodecoder NVDECCache::getDecoder(CUVIDEOFORMAT* videoFormat) { return nullptr; } -bool NVDECCache::returnDecoder( +// Evicts the least-recently-used entry from cache_. +// Caller must hold cacheLock_!!! 
+void NVDECCache::evictLRUEntry() { + if (cache_.empty()) { + return; + } + auto victim = cache_.begin(); + for (auto it = cache_.begin(); it != cache_.end(); ++it) { + if (it->second.lastUsed < victim->second.lastUsed) { + victim = it; + } + } + cache_.erase(victim); +} + +void NVDECCache::returnDecoder( CUVIDEOFORMAT* videoFormat, UniqueCUvideodecoder decoder) { - if (!decoder) { - return false; - } + STD_TORCH_CHECK(decoder != nullptr, "decoder must not be null"); CacheKey key(videoFormat); std::lock_guard lock(cacheLock_); - if (cache_.size() >= MAX_CACHE_SIZE) { - return false; + int capacity = getNVDECCacheCapacity(); + if (capacity <= 0) { + return; } - cache_[key] = std::move(decoder); - return true; + // Evict least recently used entries until under capacity. + // This search is O(capacity), which is supposed to be small, + // so linear vs constant search overhead is expected to be negligible. + while (cache_.size() >= static_cast(capacity)) { + evictLRUEntry(); + } + + // Add the decoder back to cache + cache_.emplace(key, CacheEntry(std::move(decoder), lastUsedCounter_++)); + + STD_TORCH_CHECK( + cache_.size() <= static_cast(capacity), + "Cache size exceeded capacity, please report a bug"); +} + +void NVDECCache::evictExcessEntriesAcrossDevices(int capacity) { + NVDECCache* instances = getCacheInstances(); + for (int i = 0; i < MAX_CUDA_GPUS; ++i) { + std::lock_guard lock(instances[i].cacheLock_); + while (instances[i].cache_.size() > static_cast(capacity)) { + instances[i].evictLRUEntry(); + } + } +} + +int NVDECCache::getCacheSizeForDevice(int device_index) { + NVDECCache* instances = getCacheInstances(); + std::lock_guard lock(instances[device_index].cacheLock_); + return static_cast(instances[device_index].cache_.size()); } } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/NVDECCache.h b/src/torchcodec/_core/NVDECCache.h index a0f2fb862..0d373e998 100644 --- a/src/torchcodec/_core/NVDECCache.h +++ 
b/src/torchcodec/_core/NVDECCache.h @@ -11,11 +11,12 @@ #include #include -#include -#include "src/torchcodec/_core/NVCUVIDRuntimeLoader.h" -#include "src/torchcodec/_core/nvcuvid_include/cuviddec.h" -#include "src/torchcodec/_core/nvcuvid_include/nvcuvid.h" +#include "NVCUVIDRuntimeLoader.h" +#include "NVDECCacheConfig.h" +#include "StableABICompat.h" +#include "nvcuvid_include/cuviddec.h" +#include "nvcuvid_include/nvcuvid.h" namespace facebook::torchcodec { @@ -35,17 +36,33 @@ struct CUvideoDecoderDeleter { using UniqueCUvideodecoder = std::unique_ptr; -// A per-device cache for NVDEC decoders. There is one instance of this class -// per GPU device, and it is accessed through the static getCache() method. +struct CacheEntry { + UniqueCUvideodecoder decoder; + uint64_t lastUsed; // LRU timestamp + + CacheEntry(UniqueCUvideodecoder dec, uint64_t ts) + : decoder(std::move(dec)), lastUsed(ts) {} +}; + +// A per-device LRU cache for NVDEC decoders. There is one instance of this +// class per GPU device, and it is accessed through the static getCache() +// method. The cache supports multiple decoders with the same parameters. class NVDECCache { public: - static NVDECCache& getCache(const torch::Device& device); + static NVDECCache& getCache(const StableDevice& device); - // Get decoder from cache - returns nullptr if none available + // Get decoder from cache - returns nullptr if none available. UniqueCUvideodecoder getDecoder(CUVIDEOFORMAT* videoFormat); - // Return decoder to cache - returns true if added to cache - bool returnDecoder(CUVIDEOFORMAT* videoFormat, UniqueCUvideodecoder decoder); + // Return decoder to cache using LRU eviction. + void returnDecoder(CUVIDEOFORMAT* videoFormat, UniqueCUvideodecoder decoder); + + // Iterates all per-device cache instances and evicts LRU entries until each + // cache's size is at most capacity. Called from setNVDECCacheCapacity(). 
+ static void evictExcessEntriesAcrossDevices(int capacity); + + // Returns the number of entries in the cache for a given device index. + static int getCacheSizeForDevice(int device_index); private: // Cache key struct: a decoder can be reused and taken from the cache only if @@ -60,13 +77,15 @@ class NVDECCache { CacheKey() = delete; - explicit CacheKey(CUVIDEOFORMAT* videoFormat) - : codecType(videoFormat->codec), - width(videoFormat->coded_width), - height(videoFormat->coded_height), - chromaFormat(videoFormat->chroma_format), - bitDepthLumaMinus8(videoFormat->bit_depth_luma_minus8), - numDecodeSurfaces(videoFormat->min_num_decode_surfaces) {} + explicit CacheKey(CUVIDEOFORMAT* videoFormat) { + STD_TORCH_CHECK(videoFormat != nullptr, "videoFormat must not be null"); + codecType = videoFormat->codec; + width = videoFormat->coded_width; + height = videoFormat->coded_height; + chromaFormat = videoFormat->chroma_format; + bitDepthLumaMinus8 = videoFormat->bit_depth_luma_minus8; + numDecodeSurfaces = videoFormat->min_num_decode_surfaces; + } CacheKey(const CacheKey&) = default; CacheKey& operator=(const CacheKey&) = default; @@ -92,11 +111,13 @@ class NVDECCache { NVDECCache() = default; ~NVDECCache() = default; - std::map cache_; - std::mutex cacheLock_; + void evictLRUEntry(); - // Max number of cached decoders, per device - static constexpr int MAX_CACHE_SIZE = 20; + static NVDECCache* getCacheInstances(); + + std::multimap cache_; + std::mutex cacheLock_; + uint64_t lastUsedCounter_ = 0; }; } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/NVDECCacheConfig.cpp b/src/torchcodec/_core/NVDECCacheConfig.cpp new file mode 100644 index 000000000..f86079385 --- /dev/null +++ b/src/torchcodec/_core/NVDECCacheConfig.cpp @@ -0,0 +1,62 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
+ +#include "NVDECCacheConfig.h" + +#include +#include + +#include "c10/util/Exception.h" + +#ifdef USE_CUDA +#include "CUDACommon.h" +#include "NVDECCache.h" +#endif + +namespace facebook::torchcodec { + +static std::atomic g_nvdecCacheCapacity{DEFAULT_NVDEC_CACHE_CAPACITY}; +// This mutex serializes setNVDECCacheCapacity() calls so that the atomic store +// and the subsequent cache eviction happen as one unit. getNVDECCacheCapacity() +// intentionally reads the atomic without this mutex: callers like +// returnDecoder() may briefly see a stale value during an ongoing +// setNVDECCacheCapacity(), which is acceptable because the worst case is a +// single decoder being added back to the cache after eviction. That entry will +// be consumed by a subsequent getDecoder() call or evicted by a future +// returnDecoder() or setNVDECCacheCapacity() call. +static std::mutex g_nvdecCacheCapacityMutex; + +void setNVDECCacheCapacity(int capacity) { + TORCH_CHECK( + capacity >= 0, + "NVDEC cache capacity must be non-negative, got ", + capacity); + std::lock_guard lock(g_nvdecCacheCapacityMutex); + g_nvdecCacheCapacity.store(capacity); +#ifdef USE_CUDA + NVDECCache::evictExcessEntriesAcrossDevices(capacity); +#endif +} + +int getNVDECCacheCapacity() { + return g_nvdecCacheCapacity.load(); +} + +int getNVDECCacheSize([[maybe_unused]] int device_index) { +#ifdef USE_CUDA + TORCH_CHECK( + device_index >= 0 && device_index < MAX_CUDA_GPUS, + "device_index must be between 0 and ", + MAX_CUDA_GPUS - 1, + ", got ", + device_index); + return NVDECCache::getCacheSizeForDevice(device_index); +#else + return 0; +#endif +} + +} // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/NVDECCacheConfig.h b/src/torchcodec/_core/NVDECCacheConfig.h new file mode 100644 index 000000000..4fc5764a6 --- /dev/null +++ b/src/torchcodec/_core/NVDECCacheConfig.h @@ -0,0 +1,30 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. 
+// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +// This header is intentionally CUDA-free so it can be included from +// custom_ops.cpp which is compiled without CUDA headers. + +namespace facebook::torchcodec { + +// Default capacity of the per-device NVDEC decoder cache. +// capacity == maximum number of cached instances allowed. +constexpr int DEFAULT_NVDEC_CACHE_CAPACITY = 20; + +// Set the capacity of the per-device NVDEC decoder cache. +// capacity must be non-negative. +void setNVDECCacheCapacity(int capacity); + +// Get the current capacity of the per-device NVDEC decoder cache. +int getNVDECCacheCapacity(); + +// Get the current number of entries in the NVDEC decoder cache for a device. +// This is currently only used for tests, and not publicly exposed. +// TODO expose it? +int getNVDECCacheSize(int device_index); + +} // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp index bd87f12d3..32b2b1d65 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.cpp +++ b/src/torchcodec/_core/SingleStreamDecoder.cpp @@ -4,16 +4,17 @@ // This source code is licensed under the BSD-style license found in the // LICENSE file in the root directory of this source tree. 
-#include "src/torchcodec/_core/SingleStreamDecoder.h" +#include "SingleStreamDecoder.h" #include #include #include -#include -#include -#include #include #include "Metadata.h" -#include "torch/types.h" +#include "StableABICompat.h" + +extern "C" { +#include +} namespace facebook::torchcodec { namespace { @@ -47,11 +48,11 @@ SingleStreamDecoder::SingleStreamDecoder( AVFormatContext* rawContext = nullptr; int status = avformat_open_input(&rawContext, videoFilePath.c_str(), nullptr, nullptr); - TORCH_CHECK( + STD_TORCH_CHECK( status == 0, "Could not open input file: " + videoFilePath + " " + getFFMPEGErrorStringFromErrorCode(status)); - TORCH_CHECK(rawContext != nullptr); + STD_TORCH_CHECK(rawContext != nullptr, "Failed to allocate AVFormatContext"); formatContext_.reset(rawContext); initializeDecoder(); @@ -63,19 +64,19 @@ SingleStreamDecoder::SingleStreamDecoder( : seekMode_(seekMode), avioContextHolder_(std::move(context)) { setFFmpegLogLevel(); - TORCH_CHECK(avioContextHolder_, "Context holder cannot be null"); + STD_TORCH_CHECK(avioContextHolder_, "Context holder cannot be null"); // Because FFmpeg requires a reference to a pointer in the call to open, we // can't use a unique pointer here. Note that means we must call free if open // fails. 
AVFormatContext* rawContext = avformat_alloc_context(); - TORCH_CHECK(rawContext != nullptr, "Unable to alloc avformat context"); + STD_TORCH_CHECK(rawContext != nullptr, "Unable to alloc avformat context"); rawContext->pb = avioContextHolder_->getAVIOContext(); int status = avformat_open_input(&rawContext, nullptr, nullptr, nullptr); if (status != 0) { avformat_free_context(rawContext); - TORCH_CHECK( + STD_TORCH_CHECK( false, "Failed to open input buffer: " + getFFMPEGErrorStringFromErrorCode(status)); @@ -87,7 +88,7 @@ SingleStreamDecoder::SingleStreamDecoder( } void SingleStreamDecoder::initializeDecoder() { - TORCH_CHECK(!initialized_, "Attempted double initialization."); + STD_TORCH_CHECK(!initialized_, "Attempted double initialization."); // In principle, the AVFormatContext should be filled in by the call to // avformat_open_input() which reads the header. However, some formats do not @@ -95,23 +96,43 @@ void SingleStreamDecoder::initializeDecoder() { // which decodes a few frames to get missing info. 
For more, see: // https://ffmpeg.org/doxygen/7.0/group__lavf__decoding.html int status = avformat_find_stream_info(formatContext_.get(), nullptr); - TORCH_CHECK( + STD_TORCH_CHECK( status >= 0, "Failed to find stream info: ", getFFMPEGErrorStringFromErrorCode(status)); + if (formatContext_->duration > 0) { + AVRational defaultTimeBase{1, AV_TIME_BASE}; + containerMetadata_.durationSecondsFromHeader = + ptsToSeconds(formatContext_->duration, defaultTimeBase); + } + + if (formatContext_->bit_rate > 0) { + containerMetadata_.bitRate = formatContext_->bit_rate; + } + + int bestVideoStream = getBestStreamIndex(AVMEDIA_TYPE_VIDEO); + if (bestVideoStream >= 0) { + containerMetadata_.bestVideoStreamIndex = bestVideoStream; + } + + int bestAudioStream = getBestStreamIndex(AVMEDIA_TYPE_AUDIO); + if (bestAudioStream >= 0) { + containerMetadata_.bestAudioStreamIndex = bestAudioStream; + } + for (unsigned int i = 0; i < formatContext_->nb_streams; i++) { AVStream* avStream = formatContext_->streams[i]; StreamMetadata streamMetadata; - TORCH_CHECK( + STD_TORCH_CHECK( static_cast(i) == avStream->index, "Our stream index, " + std::to_string(i) + ", does not match AVStream's index, " + std::to_string(avStream->index) + "."); streamMetadata.streamIndex = i; - streamMetadata.mediaType = avStream->codecpar->codec_type; streamMetadata.codecName = avcodec_get_name(avStream->codecpar->codec_id); + streamMetadata.mediaType = avStream->codecpar->codec_type; streamMetadata.bitRate = avStream->codecpar->bit_rate; int64_t frameCount = avStream->nb_frames; @@ -133,10 +154,52 @@ void SingleStreamDecoder::initializeDecoder() { if (fps > 0) { streamMetadata.averageFpsFromHeader = fps; } + streamMetadata.rotation = getRotationFromStream(avStream); + + // Report post-rotation dimensions: swap width/height for 90 or -90 + // degree rotations so metadata matches what the decoder returns. 
+ int width = avStream->codecpar->width; + int height = avStream->codecpar->height; + Rotation rotation = rotationFromDegrees(streamMetadata.rotation); + // 90° rotations swap dimensions + if (rotation == Rotation::CCW90 || rotation == Rotation::CW90) { + std::swap(width, height); + } + streamMetadata.postRotationWidth = width; + streamMetadata.postRotationHeight = height; + + streamMetadata.sampleAspectRatio = + avStream->codecpar->sample_aspect_ratio; + + if (avStream->codecpar->color_primaries != AVCOL_PRI_UNSPECIFIED) { + streamMetadata.colorPrimaries = avStream->codecpar->color_primaries; + } + if (avStream->codecpar->color_space != AVCOL_SPC_UNSPECIFIED) { + streamMetadata.colorSpace = avStream->codecpar->color_space; + } + if (avStream->codecpar->color_trc != AVCOL_TRC_UNSPECIFIED) { + streamMetadata.colorTransferCharacteristic = + avStream->codecpar->color_trc; + } + AVPixelFormat pixelFormat = + static_cast(avStream->codecpar->format); + // If the AVPixelFormat is not recognized, we get back nullptr. We have + // to make sure we don't initialize a std::string with nullptr. There's + // nothing to do on the else branch because we're already using an + // optional; it'll just remain empty. + const char* rawPixelFormat = av_get_pix_fmt_name(pixelFormat); + if (rawPixelFormat != nullptr) { + streamMetadata.pixelFormat = std::string(rawPixelFormat); + } + containerMetadata_.numVideoStreams++; } else if (avStream->codecpar->codec_type == AVMEDIA_TYPE_AUDIO) { AVSampleFormat format = static_cast(avStream->codecpar->format); + streamMetadata.sampleRate = + static_cast(avStream->codecpar->sample_rate); + streamMetadata.numChannels = + static_cast(getNumChannels(avStream->codecpar)); // If the AVSampleFormat is not recognized, we get back nullptr. We have // to make sure we don't initialize a std::string with nullptr. 
There's @@ -149,27 +212,10 @@ void SingleStreamDecoder::initializeDecoder() { containerMetadata_.numAudioStreams++; } - containerMetadata_.allStreamMetadata.push_back(streamMetadata); - } - - if (formatContext_->duration > 0) { - AVRational defaultTimeBase{1, AV_TIME_BASE}; - containerMetadata_.durationSecondsFromHeader = - ptsToSeconds(formatContext_->duration, defaultTimeBase); - } - - if (formatContext_->bit_rate > 0) { - containerMetadata_.bitRate = formatContext_->bit_rate; - } - - int bestVideoStream = getBestStreamIndex(AVMEDIA_TYPE_VIDEO); - if (bestVideoStream >= 0) { - containerMetadata_.bestVideoStreamIndex = bestVideoStream; - } + streamMetadata.durationSecondsFromContainer = + containerMetadata_.durationSecondsFromHeader; - int bestAudioStream = getBestStreamIndex(AVMEDIA_TYPE_AUDIO); - if (bestAudioStream >= 0) { - containerMetadata_.bestAudioStreamIndex = bestAudioStream; + containerMetadata_.allStreamMetadata.push_back(streamMetadata); } if (seekMode_ == SeekMode::exact) { @@ -213,7 +259,7 @@ void SingleStreamDecoder::sortAllFrames() { for (size_t i = 0; i < streamInfo.allFrames.size(); ++i) { streamInfo.allFrames[i].frameIndex = i; if (streamInfo.allFrames[i].isKeyFrame) { - TORCH_CHECK( + STD_TORCH_CHECK( keyFrameIndex < streamInfo.keyFrames.size(), "The allFrames vec claims it has MORE keyFrames than the keyFrames vec. There's a bug in torchcodec."); streamInfo.keyFrames[keyFrameIndex].frameIndex = i; @@ -223,7 +269,7 @@ void SingleStreamDecoder::sortAllFrames() { streamInfo.allFrames[i].nextPts = streamInfo.allFrames[i + 1].pts; } } - TORCH_CHECK( + STD_TORCH_CHECK( keyFrameIndex == streamInfo.keyFrames.size(), "The allFrames vec claims it has LESS keyFrames than the keyFrames vec. 
There's a bug in torchcodec."); } @@ -245,7 +291,7 @@ void SingleStreamDecoder::scanFileAndUpdateMetadataAndIndex() { break; } - TORCH_CHECK( + STD_TORCH_CHECK( status == AVSUCCESS, "Failed to read frame from input file: ", getFFMPEGErrorStringFromErrorCode(status)); @@ -288,6 +334,14 @@ void SingleStreamDecoder::scanFileAndUpdateMetadataAndIndex() { streamMetadata.numFramesFromContent = streamInfos_[streamIndex].allFrames.size(); + // This ensures that we are robust in handling cases where + // we are decoding in exact mode and numFrames is 0. The current metadata + // validation logic assumes that these values should not be None + if (streamMetadata.numFramesFromContent.value() == 0) { + streamMetadata.beginStreamPtsFromContent = 0; + streamMetadata.endStreamPtsFromContent = 0; + } + if (streamMetadata.beginStreamPtsFromContent.has_value()) { streamMetadata.beginStreamPtsSecondsFromContent = ptsToSeconds( *streamMetadata.beginStreamPtsFromContent, avStream->time_base); @@ -300,7 +354,7 @@ void SingleStreamDecoder::scanFileAndUpdateMetadataAndIndex() { // Reset the seek-cursor back to the beginning. 
int status = avformat_seek_file(formatContext_.get(), 0, INT64_MIN, 0, 0, 0); - TORCH_CHECK( + STD_TORCH_CHECK( status >= 0, "Could not seek file to pts=0: ", getFFMPEGErrorStringFromErrorCode(status)); @@ -313,29 +367,26 @@ void SingleStreamDecoder::scanFileAndUpdateMetadataAndIndex() { void SingleStreamDecoder::readCustomFrameMappingsUpdateMetadataAndIndex( int streamIndex, FrameMappings customFrameMappings) { - TORCH_CHECK( - customFrameMappings.all_frames.dtype() == torch::kLong && - customFrameMappings.is_key_frame.dtype() == torch::kBool && - customFrameMappings.duration.dtype() == torch::kLong, + STD_TORCH_CHECK( + customFrameMappings.all_frames.scalar_type() == kStableInt64 && + customFrameMappings.is_key_frame.scalar_type() == kStableBool && + customFrameMappings.duration.scalar_type() == kStableInt64, "all_frames and duration tensors must be int64 dtype, and is_key_frame tensor must be a bool dtype."); - const torch::Tensor& all_frames = - customFrameMappings.all_frames.to(torch::kLong); - const torch::Tensor& is_key_frame = - customFrameMappings.is_key_frame.to(torch::kBool); - const torch::Tensor& duration = customFrameMappings.duration.to(torch::kLong); - TORCH_CHECK( - all_frames.size(0) == is_key_frame.size(0) && - is_key_frame.size(0) == duration.size(0), + const torch::stable::Tensor& all_frames = customFrameMappings.all_frames; + const torch::stable::Tensor& is_key_frame = customFrameMappings.is_key_frame; + const torch::stable::Tensor& duration = customFrameMappings.duration; + STD_TORCH_CHECK( + all_frames.sizes()[0] == is_key_frame.sizes()[0] && + is_key_frame.sizes()[0] == duration.sizes()[0], "all_frames, is_key_frame, and duration from custom_frame_mappings were not same size."); // Allocate vectors using num frames to reduce reallocations - int64_t numFrames = all_frames.size(0); + int64_t numFrames = all_frames.sizes()[0]; streamInfos_[streamIndex].allFrames.reserve(numFrames); streamInfos_[streamIndex].keyFrames.reserve(numFrames); - // 
Use accessor to efficiently access tensor elements - auto pts_data = all_frames.accessor(); - auto is_key_frame_data = is_key_frame.accessor(); - auto duration_data = duration.accessor(); + auto pts_data = constAccessor(all_frames); + auto is_key_frame_data = constAccessor(is_key_frame); + auto duration_data = constAccessor(duration); auto& streamMetadata = containerMetadata_.allStreamMetadata[streamIndex]; @@ -367,16 +418,25 @@ ContainerMetadata SingleStreamDecoder::getContainerMetadata() const { return containerMetadata_; } -torch::Tensor SingleStreamDecoder::getKeyFrameIndices() { +SeekMode SingleStreamDecoder::getSeekMode() const { + return seekMode_; +} + +int SingleStreamDecoder::getActiveStreamIndex() const { + return activeStreamIndex_; +} + +torch::stable::Tensor SingleStreamDecoder::getKeyFrameIndices() { validateActiveStream(AVMEDIA_TYPE_VIDEO); validateScannedAllStreams("getKeyFrameIndices"); const std::vector& keyFrames = streamInfos_[activeStreamIndex_].keyFrames; - torch::Tensor keyFrameIndices = - torch::empty({static_cast(keyFrames.size())}, {torch::kInt64}); + torch::stable::Tensor keyFrameIndices = torch::stable::empty( + {static_cast(keyFrames.size())}, kStableInt64); + auto keyFrameIndicesAccessor = mutableAccessor(keyFrameIndices); for (size_t i = 0; i < keyFrames.size(); ++i) { - keyFrameIndices[i] = keyFrames[i].frameIndex; + keyFrameIndicesAccessor[i] = keyFrames[i].frameIndex; } return keyFrameIndices; @@ -389,29 +449,29 @@ torch::Tensor SingleStreamDecoder::getKeyFrameIndices() { void SingleStreamDecoder::addStream( int streamIndex, AVMediaType mediaType, - const torch::Device& device, + const StableDevice& device, const std::string_view deviceVariant, std::optional ffmpegThreadCount) { - TORCH_CHECK( + STD_TORCH_CHECK( activeStreamIndex_ == NO_ACTIVE_STREAM, "Can only add one single stream."); - TORCH_CHECK( + STD_TORCH_CHECK( mediaType == AVMEDIA_TYPE_VIDEO || mediaType == AVMEDIA_TYPE_AUDIO, "Can only add video or audio streams."); - 
TORCH_CHECK(formatContext_.get() != nullptr); + STD_TORCH_CHECK(formatContext_.get() != nullptr, "Format context is null"); AVCodecOnlyUseForCallingAVFindBestStream avCodec = nullptr; activeStreamIndex_ = av_find_best_stream( formatContext_.get(), mediaType, streamIndex, -1, &avCodec, 0); - if (activeStreamIndex_ < 0) { - throw std::invalid_argument( - "No valid stream found in input file. Is " + - std::to_string(streamIndex) + " of the desired media type?"); - } + STD_TORCH_CHECK( + activeStreamIndex_ >= 0, + "No valid stream found in input file. Is ", + streamIndex, + " of the desired media type?"); - TORCH_CHECK(avCodec != nullptr); + STD_TORCH_CHECK(avCodec != nullptr, "Codec not found"); StreamInfo& streamInfo = streamInfos_[activeStreamIndex_]; streamInfo.streamIndex = activeStreamIndex_; @@ -420,14 +480,14 @@ void SingleStreamDecoder::addStream( streamInfo.avMediaType = mediaType; // This should never happen, checking just to be safe. - TORCH_CHECK( + STD_TORCH_CHECK( streamInfo.stream->codecpar->codec_type == mediaType, "FFmpeg found stream with index ", activeStreamIndex_, " which is of the wrong media type."); deviceInterface_ = createDeviceInterface(device, deviceVariant); - TORCH_CHECK( + STD_TORCH_CHECK( deviceInterface_ != nullptr, "Failed to create device interface. 
This should never happen, please report."); @@ -440,12 +500,12 @@ void SingleStreamDecoder::addStream( } AVCodecContext* codecContext = avcodec_alloc_context3(avCodec); - TORCH_CHECK(codecContext != nullptr); + STD_TORCH_CHECK(codecContext != nullptr, "Failed to allocate codec context"); streamInfo.codecContext = makeSharedAVCodecContext(codecContext); int retVal = avcodec_parameters_to_context( streamInfo.codecContext.get(), streamInfo.stream->codecpar); - TORCH_CHECK_EQ(retVal, AVSUCCESS); + STD_TORCH_CHECK(retVal == AVSUCCESS, "avcodec_parameters_to_context failed"); streamInfo.codecContext->thread_count = ffmpegThreadCount.value_or(0); streamInfo.codecContext->pkt_timebase = streamInfo.stream->time_base; @@ -456,7 +516,8 @@ void SingleStreamDecoder::addStream( deviceInterface_->registerHardwareDeviceWithCodec( streamInfo.codecContext.get()); retVal = avcodec_open2(streamInfo.codecContext.get(), avCodec, nullptr); - TORCH_CHECK(retVal >= AVSUCCESS, getFFMPEGErrorStringFromErrorCode(retVal)); + STD_TORCH_CHECK( + retVal >= AVSUCCESS, getFFMPEGErrorStringFromErrorCode(retVal)); streamInfo.codecContext->time_base = streamInfo.stream->time_base; @@ -483,8 +544,8 @@ void SingleStreamDecoder::addVideoStream( std::vector& transforms, const VideoStreamOptions& videoStreamOptions, std::optional customFrameMappings) { - TORCH_CHECK( - transforms.empty() || videoStreamOptions.device == torch::kCPU, + STD_TORCH_CHECK( + transforms.empty() || videoStreamOptions.device == kStableCPU, " Transforms are only supported for CPU devices."); addStream( @@ -498,7 +559,7 @@ void SingleStreamDecoder::addVideoStream( containerMetadata_.allStreamMetadata[activeStreamIndex_]; if (seekMode_ == SeekMode::approximate) { - TORCH_CHECK( + STD_TORCH_CHECK( streamMetadata.averageFpsFromHeader.has_value(), "Seek mode is approximate, but stream ", std::to_string(activeStreamIndex_), @@ -508,30 +569,55 @@ void SingleStreamDecoder::addVideoStream( auto& streamInfo = streamInfos_[activeStreamIndex_]; 
streamInfo.videoStreamOptions = videoStreamOptions; - streamMetadata.width = streamInfo.codecContext->width; - streamMetadata.height = streamInfo.codecContext->height; - streamMetadata.sampleAspectRatio = - streamInfo.codecContext->sample_aspect_ratio; - if (seekMode_ == SeekMode::custom_frame_mappings) { - TORCH_CHECK( + STD_TORCH_CHECK( customFrameMappings.has_value(), "Missing frame mappings when custom_frame_mappings seek mode is set."); readCustomFrameMappingsUpdateMetadataAndIndex( activeStreamIndex_, customFrameMappings.value()); } - metadataDims_ = - FrameDims(streamMetadata.height.value(), streamMetadata.width.value()); + // Set preRotationDims_ for the active stream. These are the raw encoded + // dimensions from FFmpeg, used as a fallback for tensor pre-allocation when + // no resize/rotation transforms are applied. + preRotationDims_ = FrameDims( + streamInfo.stream->codecpar->height, streamInfo.stream->codecpar->width); + + FrameDims currInputDims = preRotationDims_; + + // If there's rotation, prepend a RotationTransform to handle it in the + // filter graph. This way user transforms (resize, crop) operate in + // post-rotation coordinate space, preserving x/y coordinates for crops. + // + // It is critical to apply the rotation *before* any user-supplied + // transforms. By design, we want: + // A: VideoDecoder(..., transforms=tv_transforms)[i] + // to be equivalent to: + // B: tv_transforms(VideoDecoder(...)[i]) + // In B, rotation is applied before transforms, so A must behave the same. 
+ // + // TODO: benchmark the performance of doing this additional filtergraph + // transform + Rotation rotation = rotationFromDegrees(streamMetadata.rotation); + if (rotation != Rotation::NONE) { + auto rotationTransform = + std::make_unique(rotation, currInputDims); + currInputDims = rotationTransform->getOutputFrameDims().value(); + resizedOutputDims_ = currInputDims; + transforms_.push_back(std::move(rotationTransform)); + } + + // Note that we are claiming ownership of the transform objects passed in to + // us. + // Validate and add user transforms for (auto& transform : transforms) { - TORCH_CHECK(transform != nullptr, "Transforms should never be nullptr!"); + STD_TORCH_CHECK( + transform != nullptr, "Transforms should never be nullptr!"); + transform->validate(currInputDims); if (transform->getOutputFrameDims().has_value()) { resizedOutputDims_ = transform->getOutputFrameDims().value(); + currInputDims = resizedOutputDims_.value(); } - transform->validate(streamMetadata); - - // Note that we are claiming ownership of the transform objects passed in to - // us. transforms_.push_back(std::unique_ptr(transform)); } @@ -542,34 +628,33 @@ void SingleStreamDecoder::addVideoStream( void SingleStreamDecoder::addAudioStream( int streamIndex, const AudioStreamOptions& audioStreamOptions) { - TORCH_CHECK( + STD_TORCH_CHECK( seekMode_ == SeekMode::approximate, "seek_mode must be 'approximate' for audio streams."); if (audioStreamOptions.numChannels.has_value()) { - TORCH_CHECK( - *audioStreamOptions.numChannels > 0 && - *audioStreamOptions.numChannels <= AV_NUM_DATA_POINTERS, - "num_channels must be > 0 and <= AV_NUM_DATA_POINTERS (usually 8). Got: ", + STD_TORCH_CHECK( + *audioStreamOptions.numChannels > 0, + "num_channels must be > 0. 
Got: ", *audioStreamOptions.numChannels); } - addStream(streamIndex, AVMEDIA_TYPE_AUDIO); + // We hardcode ffmpegThreadCount=1 for audio, see + // https://github.com/pytorch/torchcodec/issues/1253 and + // https://github.com/pytorch/torchcodec/pull/1254 + addStream( + streamIndex, AVMEDIA_TYPE_AUDIO, StableDevice(kStableCPU), "ffmpeg", 1); auto& streamInfo = streamInfos_[activeStreamIndex_]; streamInfo.audioStreamOptions = audioStreamOptions; - auto& streamMetadata = - containerMetadata_.allStreamMetadata[activeStreamIndex_]; - streamMetadata.sampleRate = - static_cast(streamInfo.codecContext->sample_rate); - streamMetadata.numChannels = - static_cast(getNumChannels(streamInfo.codecContext)); - // FFmpeg docs say that the decoder will try to decode natively in this // format, if it can. Docs don't say what the decoder does when it doesn't // support that format, but it looks like it does nothing, so this probably // doesn't hurt. streamInfo.codecContext->request_sample_fmt = AV_SAMPLE_FMT_FLTP; + + // Initialize device interface for audio + deviceInterface_->initializeAudio(audioStreamOptions); } // -------------------------------------------------------------------------- @@ -585,7 +670,7 @@ FrameOutput SingleStreamDecoder::getNextFrame() { } FrameOutput SingleStreamDecoder::getNextFrameInternal( - std::optional preAllocatedOutputTensor) { + std::optional preAllocatedOutputTensor) { validateActiveStream(); UniqueAVFrame avFrame = decodeAVFrame([this](const UniqueAVFrame& avFrame) { return getPtsOrDts(avFrame) >= cursor_; @@ -601,34 +686,42 @@ FrameOutput SingleStreamDecoder::getFrameAtIndex(int64_t frameIndex) { FrameOutput SingleStreamDecoder::getFrameAtIndexInternal( int64_t frameIndex, - std::optional preAllocatedOutputTensor) { + std::optional preAllocatedOutputTensor) { validateActiveStream(AVMEDIA_TYPE_VIDEO); const auto& streamInfo = streamInfos_[activeStreamIndex_]; const auto& streamMetadata = containerMetadata_.allStreamMetadata[activeStreamIndex_]; - 
std::optional numFrames = getNumFrames(streamMetadata); + std::optional numFrames = streamMetadata.getNumFrames(seekMode_); if (numFrames.has_value()) { // If the frameIndex is negative, we convert it to a positive index frameIndex = frameIndex >= 0 ? frameIndex : frameIndex + numFrames.value(); } validateFrameIndex(streamMetadata, frameIndex); - int64_t pts = getPts(frameIndex); - setCursorPtsInSeconds(ptsToSeconds(pts, streamInfo.timeBase)); - return getNextFrameInternal(preAllocatedOutputTensor); + // Only set cursor if we're not decoding sequentially: when decoding + // sequentially, we don't need to seek anywhere, so by *not* setting the + // cursor we allow canWeAvoidSeeking() to return true early. + if (frameIndex != lastDecodedFrameIndex_ + 1) { + int64_t pts = getPts(frameIndex); + setCursorPtsInSeconds(ptsToSeconds(pts, streamInfo.timeBase)); + } + + auto result = getNextFrameInternal(preAllocatedOutputTensor); + lastDecodedFrameIndex_ = frameIndex; + return result; } FrameBatchOutput SingleStreamDecoder::getFramesAtIndices( - const torch::Tensor& frameIndices) { + const torch::stable::Tensor& frameIndices) { validateActiveStream(AVMEDIA_TYPE_VIDEO); - auto frameIndicesAccessor = frameIndices.accessor(); + auto frameIndicesData = constAccessor(frameIndices); bool indicesAreSorted = true; for (int64_t i = 1; i < frameIndices.numel(); ++i) { - if (frameIndicesAccessor[i] < frameIndicesAccessor[i - 1]) { + if (frameIndicesData[i] < frameIndicesData[i - 1]) { indicesAreSorted = false; break; } @@ -647,37 +740,43 @@ FrameBatchOutput SingleStreamDecoder::getFramesAtIndices( std::sort( argsort.begin(), argsort.end(), - [&frameIndicesAccessor](size_t a, size_t b) { - return frameIndicesAccessor[a] < frameIndicesAccessor[b]; + [&frameIndicesData](size_t a, size_t b) { + return frameIndicesData[a] < frameIndicesData[b]; }); } const auto& streamInfo = streamInfos_[activeStreamIndex_]; const auto& videoStreamOptions = streamInfo.videoStreamOptions; FrameBatchOutput 
frameBatchOutput( - frameIndices.numel(), - resizedOutputDims_.value_or(metadataDims_), - videoStreamOptions.device); + frameIndices.numel(), getOutputDims(), videoStreamOptions.device); + + auto frameBatchOutputPtsSeconds = + mutableAccessor(frameBatchOutput.ptsSeconds); + auto frameBatchOutputDurationSeconds = + mutableAccessor(frameBatchOutput.durationSeconds); auto previousIndexInVideo = -1; for (int64_t f = 0; f < frameIndices.numel(); ++f) { auto indexInOutput = indicesAreSorted ? f : argsort[f]; - auto indexInVideo = frameIndicesAccessor[indexInOutput]; + auto indexInVideo = frameIndicesData[indexInOutput]; if ((f > 0) && (indexInVideo == previousIndexInVideo)) { // Avoid decoding the same frame twice auto previousIndexInOutput = indicesAreSorted ? f - 1 : argsort[f - 1]; - frameBatchOutput.data[indexInOutput].copy_( - frameBatchOutput.data[previousIndexInOutput]); - frameBatchOutput.ptsSeconds[indexInOutput] = - frameBatchOutput.ptsSeconds[previousIndexInOutput]; - frameBatchOutput.durationSeconds[indexInOutput] = - frameBatchOutput.durationSeconds[previousIndexInOutput]; + copyFrame( + frameBatchOutput.data, + indexInOutput, + frameBatchOutput.data, + previousIndexInOutput); + frameBatchOutputPtsSeconds[indexInOutput] = + frameBatchOutputPtsSeconds[previousIndexInOutput]; + frameBatchOutputDurationSeconds[indexInOutput] = + frameBatchOutputDurationSeconds[previousIndexInOutput]; } else { FrameOutput frameOutput = getFrameAtIndexInternal( - indexInVideo, frameBatchOutput.data[indexInOutput]); - frameBatchOutput.ptsSeconds[indexInOutput] = frameOutput.ptsSeconds; - frameBatchOutput.durationSeconds[indexInOutput] = + indexInVideo, selectRow(frameBatchOutput.data, indexInOutput)); + frameBatchOutputPtsSeconds[indexInOutput] = frameOutput.ptsSeconds; + frameBatchOutputDurationSeconds[indexInOutput] = frameOutput.durationSeconds; } previousIndexInVideo = indexInVideo; @@ -695,16 +794,16 @@ FrameBatchOutput SingleStreamDecoder::getFramesInRange( const auto& 
streamMetadata = containerMetadata_.allStreamMetadata[activeStreamIndex_]; const auto& streamInfo = streamInfos_[activeStreamIndex_]; - TORCH_CHECK( + STD_TORCH_CHECK( start >= 0, "Range start, " + std::to_string(start) + " is less than 0."); - TORCH_CHECK( + STD_TORCH_CHECK( step > 0, "Step must be greater than 0; is " + std::to_string(step)); // Note that if we do not have the number of frames available in our // metadata, then we assume that the upper part of the range is valid. - std::optional numFrames = getNumFrames(streamMetadata); + std::optional numFrames = streamMetadata.getNumFrames(seekMode_); if (numFrames.has_value()) { - TORCH_CHECK( + STD_TORCH_CHECK( stop <= numFrames.value(), "Range stop, " + std::to_string(stop) + ", is more than the number of frames, " + @@ -714,15 +813,17 @@ FrameBatchOutput SingleStreamDecoder::getFramesInRange( int64_t numOutputFrames = std::ceil((stop - start) / double(step)); const auto& videoStreamOptions = streamInfo.videoStreamOptions; FrameBatchOutput frameBatchOutput( - numOutputFrames, - resizedOutputDims_.value_or(metadataDims_), - videoStreamOptions.device); + numOutputFrames, getOutputDims(), videoStreamOptions.device); + auto frameBatchOutputPtsSeconds = + mutableAccessor(frameBatchOutput.ptsSeconds); + auto frameBatchOutputDurationSeconds = + mutableAccessor(frameBatchOutput.durationSeconds); for (int64_t i = start, f = 0; i < stop; i += step, ++f) { FrameOutput frameOutput = - getFrameAtIndexInternal(i, frameBatchOutput.data[f]); - frameBatchOutput.ptsSeconds[f] = frameOutput.ptsSeconds; - frameBatchOutput.durationSeconds[f] = frameOutput.durationSeconds; + getFrameAtIndexInternal(i, selectRow(frameBatchOutput.data, f)); + frameBatchOutputPtsSeconds[f] = frameOutput.ptsSeconds; + frameBatchOutputDurationSeconds[f] = frameOutput.durationSeconds; } frameBatchOutput.data = maybePermuteHWC2CHW(frameBatchOutput.data); return frameBatchOutput; @@ -770,28 +871,29 @@ FrameOutput 
SingleStreamDecoder::getFramePlayedAt(double seconds) { } FrameBatchOutput SingleStreamDecoder::getFramesPlayedAt( - const torch::Tensor& timestamps) { + const torch::stable::Tensor& timestamps) { validateActiveStream(AVMEDIA_TYPE_VIDEO); const auto& streamMetadata = containerMetadata_.allStreamMetadata[activeStreamIndex_]; - double minSeconds = getMinSeconds(streamMetadata); - std::optional maxSeconds = getMaxSeconds(streamMetadata); + double minSeconds = streamMetadata.getBeginStreamSeconds(seekMode_); + std::optional maxSeconds = + streamMetadata.getEndStreamSeconds(seekMode_); // The frame played at timestamp t and the one played at timestamp `t + // eps` are probably the same frame, with the same index. The easiest way to // avoid decoding that unique frame twice is to convert the input timestamps // to indices, and leverage the de-duplication logic of getFramesAtIndices. - torch::Tensor frameIndices = - torch::empty({timestamps.numel()}, torch::kInt64); - auto frameIndicesAccessor = frameIndices.accessor(); - auto timestampsAccessor = timestamps.accessor(); + torch::stable::Tensor frameIndices = + torch::stable::empty({timestamps.numel()}, kStableInt64); + auto frameIndicesAccessor = mutableAccessor(frameIndices); + auto timestampsAccessor = constAccessor(timestamps); for (int64_t i = 0; i < timestamps.numel(); ++i) { auto frameSeconds = timestampsAccessor[i]; - TORCH_CHECK( + STD_TORCH_CHECK( frameSeconds >= minSeconds, "frame pts is " + std::to_string(frameSeconds) + "; must be greater than or equal to " + std::to_string(minSeconds) + @@ -800,7 +902,7 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedAt( // Note that if we can't determine the maximum number of seconds from the // metadata, then we assume the frame's pts is valid. 
if (maxSeconds.has_value()) { - TORCH_CHECK( + STD_TORCH_CHECK( frameSeconds < maxSeconds.value(), "frame pts is " + std::to_string(frameSeconds) + "; must be less than " + std::to_string(maxSeconds.value()) + @@ -815,11 +917,12 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedAt( FrameBatchOutput SingleStreamDecoder::getFramesPlayedInRange( double startSeconds, - double stopSeconds) { + double stopSeconds, + std::optional fps) { validateActiveStream(AVMEDIA_TYPE_VIDEO); const auto& streamMetadata = containerMetadata_.allStreamMetadata[activeStreamIndex_]; - TORCH_CHECK( + STD_TORCH_CHECK( startSeconds <= stopSeconds, "Start seconds (" + std::to_string(startSeconds) + ") must be less than or equal to stop seconds (" + @@ -847,15 +950,13 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedInRange( // below. Hence, we need this special case below. if (startSeconds == stopSeconds) { FrameBatchOutput frameBatchOutput( - 0, - resizedOutputDims_.value_or(metadataDims_), - videoStreamOptions.device); + 0, getOutputDims(), videoStreamOptions.device); frameBatchOutput.data = maybePermuteHWC2CHW(frameBatchOutput.data); return frameBatchOutput; } - double minSeconds = getMinSeconds(streamMetadata); - TORCH_CHECK( + double minSeconds = streamMetadata.getBeginStreamSeconds(seekMode_); + STD_TORCH_CHECK( startSeconds >= minSeconds, "Start seconds is " + std::to_string(startSeconds) + "; must be greater than or equal to " + std::to_string(minSeconds) + @@ -863,49 +964,96 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedInRange( // Note that if we can't determine the maximum seconds from the metadata, // then we assume upper range is valid. 
- std::optional maxSeconds = getMaxSeconds(streamMetadata); + std::optional maxSeconds = + streamMetadata.getEndStreamSeconds(seekMode_); if (maxSeconds.has_value()) { - TORCH_CHECK( + STD_TORCH_CHECK( startSeconds < maxSeconds.value(), "Start seconds is " + std::to_string(startSeconds) + "; must be less than " + std::to_string(maxSeconds.value()) + "."); - TORCH_CHECK( + STD_TORCH_CHECK( stopSeconds <= maxSeconds.value(), "Stop seconds (" + std::to_string(stopSeconds) + "; must be less than or equal to " + std::to_string(maxSeconds.value()) + ")."); } - // Note that we look at nextPts for a frame, and not its pts or duration. - // Our abstract player displays frames starting at the pts for that frame - // until the pts for the next frame. There are two consequences: - // - // 1. We ignore the duration for a frame. A frame is played until the - // next frame replaces it. This model is robust to durations being 0 or - // incorrect; our source of truth is the pts for frames. If duration is - // accurate, the nextPts for a frame would be equivalent to pts + - // duration. - // 2. In order to establish if the start of an interval maps to a - // particular frame, we need to figure out if it is ordered after the - // frame's pts, but before the next frames's pts. 
- - int64_t startFrameIndex = secondsToIndexLowerBound(startSeconds); - int64_t stopFrameIndex = secondsToIndexUpperBound(stopSeconds); - int64_t numFrames = stopFrameIndex - startFrameIndex; + // Resample frames to match the target frame rate + if (fps.has_value()) { + STD_TORCH_CHECK( + fps.value() > 0, + "fps must be positive, got " + std::to_string(fps.value())); - FrameBatchOutput frameBatchOutput( - numFrames, - resizedOutputDims_.value_or(metadataDims_), - videoStreamOptions.device); - for (int64_t i = startFrameIndex, f = 0; i < stopFrameIndex; ++i, ++f) { - FrameOutput frameOutput = - getFrameAtIndexInternal(i, frameBatchOutput.data[f]); - frameBatchOutput.ptsSeconds[f] = frameOutput.ptsSeconds; - frameBatchOutput.durationSeconds[f] = frameOutput.durationSeconds; - } - frameBatchOutput.data = maybePermuteHWC2CHW(frameBatchOutput.data); + // TODO: add an early break if requested fps is the same as the current fps - return frameBatchOutput; + double fpsVal = fps.value(); + double frameDurationSeconds = 1.0 / fpsVal; + + double product = (stopSeconds - startSeconds) * fpsVal; + int64_t numOutputFrames = static_cast(std::round(product)); + + FrameBatchOutput frameBatchOutput( + numOutputFrames, getOutputDims(), videoStreamOptions.device); + + auto frameBatchOutputPtsSeconds = + mutableAccessor(frameBatchOutput.ptsSeconds); + auto frameBatchOutputDurationSeconds = + mutableAccessor(frameBatchOutput.durationSeconds); + + // Decode frames, reusing already-decoded frames for duplicates + int64_t lastDecodedSourceIndex = -1; + + for (int64_t i = 0; i < numOutputFrames; ++i) { + double targetPtsSeconds = startSeconds + i * frameDurationSeconds; + int64_t sourceIdx = secondsToIndexLowerBound(targetPtsSeconds); + + if (sourceIdx == lastDecodedSourceIndex && lastDecodedSourceIndex >= 0) { + copyFrame(frameBatchOutput.data, i, frameBatchOutput.data, i - 1); + } else { + getFrameAtIndexInternal(sourceIdx, selectRow(frameBatchOutput.data, i)); + lastDecodedSourceIndex = 
sourceIdx; + } + + frameBatchOutputPtsSeconds[i] = targetPtsSeconds; + frameBatchOutputDurationSeconds[i] = frameDurationSeconds; + } + + frameBatchOutput.data = maybePermuteHWC2CHW(frameBatchOutput.data); + return frameBatchOutput; + } else { + // Note that we look at nextPts for a frame, and not its pts or duration. + // Our abstract player displays frames starting at the pts for that frame + // until the pts for the next frame. There are two consequences: + // + // 1. We ignore the duration for a frame. A frame is played until the + // next frame replaces it. This model is robust to durations being 0 or + // incorrect; our source of truth is the pts for frames. If duration is + // accurate, the nextPts for a frame would be equivalent to pts + + // duration. + // 2. In order to establish if the start of an interval maps to a + // particular frame, we need to figure out if it is ordered after the + // frame's pts, but before the next frames's pts. + + int64_t startFrameIndex = secondsToIndexLowerBound(startSeconds); + int64_t stopFrameIndex = secondsToIndexUpperBound(stopSeconds); + int64_t numFrames = stopFrameIndex - startFrameIndex; + + FrameBatchOutput frameBatchOutput( + numFrames, getOutputDims(), videoStreamOptions.device); + auto frameBatchOutputPtsSeconds = + mutableAccessor(frameBatchOutput.ptsSeconds); + auto frameBatchOutputDurationSeconds = + mutableAccessor(frameBatchOutput.durationSeconds); + for (int64_t i = startFrameIndex, f = 0; i < stopFrameIndex; ++i, ++f) { + FrameOutput frameOutput = + getFrameAtIndexInternal(i, selectRow(frameBatchOutput.data, f)); + frameBatchOutputPtsSeconds[f] = frameOutput.ptsSeconds; + frameBatchOutputDurationSeconds[f] = frameOutput.durationSeconds; + } + frameBatchOutput.data = maybePermuteHWC2CHW(frameBatchOutput.data); + + return frameBatchOutput; + } } // Note [Audio Decoding Design] @@ -967,7 +1115,7 @@ AudioFramesOutput SingleStreamDecoder::getFramesPlayedInRangeAudio( validateActiveStream(AVMEDIA_TYPE_AUDIO); 
if (stopSecondsOptional.has_value()) { - TORCH_CHECK( + STD_TORCH_CHECK( startSeconds <= *stopSecondsOptional, "Start seconds (" + std::to_string(startSeconds) + ") must be less than or equal to stop seconds (" + @@ -979,7 +1127,7 @@ AudioFramesOutput SingleStreamDecoder::getFramesPlayedInRangeAudio( if (stopSecondsOptional.has_value() && startSeconds == *stopSecondsOptional) { // For consistency with video int numChannels = getNumChannels(streamInfo.codecContext); - return AudioFramesOutput{torch::empty({numChannels, 0}), 0.0}; + return AudioFramesOutput{torch::stable::empty({numChannels, 0}), 0.0}; } auto startPts = secondsToClosestPts(startSeconds, streamInfo.timeBase); @@ -992,7 +1140,7 @@ AudioFramesOutput SingleStreamDecoder::getFramesPlayedInRangeAudio( // TODO-AUDIO Pre-allocate a long-enough tensor instead of creating a vec + // cat(). This would save a copy. We know the duration of the output and the // sample rate, so in theory we know the number of output samples. - std::vector frames; + std::vector frames; std::optional firstFramePtsSeconds = std::nullopt; auto stopPts = stopSecondsOptional.has_value() @@ -1011,7 +1159,7 @@ AudioFramesOutput SingleStreamDecoder::getFramesPlayedInRangeAudio( firstFramePtsSeconds = frameOutput.ptsSeconds; } frames.push_back(frameOutput.data); - } catch (const EndOfFileException& e) { + } catch (const EndOfFileException&) { finished = true; } @@ -1025,7 +1173,7 @@ AudioFramesOutput SingleStreamDecoder::getFramesPlayedInRangeAudio( (stopPts <= lastDecodedAvFrameEnd); } - auto lastSamples = maybeFlushSwrBuffers(); + auto lastSamples = deviceInterface_->maybeFlushAudioBuffers(); if (lastSamples.has_value()) { frames.push_back(*lastSamples); } @@ -1040,7 +1188,7 @@ AudioFramesOutput SingleStreamDecoder::getFramesPlayedInRangeAudio( // stopSecondsOptional, // ") is too low."); - return AudioFramesOutput{torch::cat(frames, 1), *firstFramePtsSeconds}; + return AudioFramesOutput{stableCat(frames, 1), *firstFramePtsSeconds}; } // 
-------------------------------------------------------------------------- @@ -1060,39 +1208,21 @@ void SingleStreamDecoder::setCursor(int64_t pts) { cursor_ = pts; } -/* -Videos have I frames and non-I frames (P and B frames). Non-I frames need data -from the previous I frame to be decoded. - -Imagine the cursor is at a random frame with PTS=lastDecodedAvFramePts (x for -brevity) and we wish to seek to a user-specified PTS=y. - -If y < x, we don't have a choice but to seek backwards to the highest I frame -before y. - -If y > x, we have two choices: - -1. We could keep decoding forward until we hit y. Illustrated below: - -I P P P I P P P I P P I P P I P - x y - -2. We could try to jump to an I frame between x and y (indicated by j below). -And then start decoding until we encounter y. Illustrated below: - -I P P P I P P P I P P I P P I P - x j y - -(2) is more efficient than (1) if there is an I frame between x and y. -*/ bool SingleStreamDecoder::canWeAvoidSeeking() const { + // Returns true if we can avoid seeking in the AVFormatContext based on + // heuristics that rely on the target cursor_ and the last decoded frame. + // Seeking is expensive, so we try to avoid it when possible. const StreamInfo& streamInfo = streamInfos_.at(activeStreamIndex_); if (streamInfo.avMediaType == AVMEDIA_TYPE_AUDIO) { // For audio, we only need to seek if a backwards seek was requested // within getFramesPlayedInRangeAudio(), when setCursorPtsInSeconds() was // called. For more context, see [Audio Decoding Design] return !cursorWasJustSet_; + } else if (!cursorWasJustSet_) { + // For videos, when decoding consecutive frames, we don't need to seek. + return true; } + if (cursor_ < lastDecodedAvFramePts_) { // We can never skip a seek if we are seeking backwards. return false; @@ -1104,13 +1234,34 @@ bool SingleStreamDecoder::canWeAvoidSeeking() const { // implement caching. return false; } - // We are seeking forwards. 
- // We can only skip a seek if both lastDecodedAvFramePts and - // cursor_ share the same keyframe. - int lastDecodedAvFrameIndex = getKeyFrameIndexForPts(lastDecodedAvFramePts_); - int targetKeyFrameIndex = getKeyFrameIndexForPts(cursor_); - return lastDecodedAvFrameIndex >= 0 && targetKeyFrameIndex >= 0 && - lastDecodedAvFrameIndex == targetKeyFrameIndex; + // We are seeking forwards. We can skip a seek if both the last decoded frame + // and cursor_ share the same keyframe: + // Videos have I frames and non-I frames (P and B frames). Non-I frames need + // data from the previous I frame to be decoded. + // + // Imagine the cursor is at a random frame with PTS=lastDecodedAvFramePts (x + // for brevity) and we wish to seek to a user-specified PTS=y. + // + // If y < x, we don't have a choice but to seek backwards to the highest I + // frame before y. + // + // If y > x, we have two choices: + // + // 1. We could keep decoding forward until we hit y. Illustrated below: + // + // I P P P I P P P I P P I P + // x y + // + // 2. We could try to jump to an I frame between x and y (indicated by j + // below). And then start decoding until we encounter y. Illustrated below: + // + // I P P P I P P P I P P I P + // x j y + // (2) is only more efficient than (1) if there is an I frame between x and y. 
+ int lastKeyFrame = getKeyFrameIdentifier(lastDecodedAvFramePts_); + int targetKeyFrame = getKeyFrameIdentifier(cursor_); + return lastKeyFrame >= 0 && targetKeyFrame >= 0 && + lastKeyFrame == targetKeyFrame; } // This method looks at currentPts and desiredPts and seeks in the @@ -1147,7 +1298,7 @@ void SingleStreamDecoder::maybeSeekToBeforeDesiredPts() { desiredPts, desiredPts, 0); - TORCH_CHECK( + STD_TORCH_CHECK( status >= 0, "Could not seek file to pts=", std::to_string(desiredPts), @@ -1168,10 +1319,8 @@ UniqueAVFrame SingleStreamDecoder::decodeAVFrame( resetDecodeStats(); - if (cursorWasJustSet_) { - maybeSeekToBeforeDesiredPts(); - cursorWasJustSet_ = false; - } + maybeSeekToBeforeDesiredPts(); + cursorWasJustSet_ = false; UniqueAVFrame avFrame(av_frame_alloc()); AutoAVPacket autoAVPacket; @@ -1219,7 +1368,7 @@ UniqueAVFrame SingleStreamDecoder::decodeAVFrame( if (status == AVERROR_EOF) { // End of file reached. We must drain the decoder status = deviceInterface_->sendEOFPacket(); - TORCH_CHECK( + STD_TORCH_CHECK( status >= AVSUCCESS, "Could not flush decoder: ", getFFMPEGErrorStringFromErrorCode(status)); @@ -1228,7 +1377,7 @@ UniqueAVFrame SingleStreamDecoder::decodeAVFrame( break; } - TORCH_CHECK( + STD_TORCH_CHECK( status >= AVSUCCESS, "Could not read frame from input file: ", getFFMPEGErrorStringFromErrorCode(status)); @@ -1244,7 +1393,7 @@ UniqueAVFrame SingleStreamDecoder::decodeAVFrame( // We got a valid packet. Send it to the decoder, and we'll receive it in // the next iteration. 
status = deviceInterface_->sendPacket(packet); - TORCH_CHECK( + STD_TORCH_CHECK( status >= AVSUCCESS, "Could not push packet to decoder: ", getFFMPEGErrorStringFromErrorCode(status)); @@ -1258,7 +1407,7 @@ UniqueAVFrame SingleStreamDecoder::decodeAVFrame( "Requested next frame while there are no more frames left to " "decode."); } - TORCH_CHECK( + STD_TORCH_CHECK( false, "Could not receive frame from decoder: ", getFFMPEGErrorStringFromErrorCode(status)); @@ -1283,22 +1432,17 @@ UniqueAVFrame SingleStreamDecoder::decodeAVFrame( FrameOutput SingleStreamDecoder::convertAVFrameToFrameOutput( UniqueAVFrame& avFrame, - std::optional preAllocatedOutputTensor) { + std::optional preAllocatedOutputTensor) { // Convert the frame to tensor. FrameOutput frameOutput; - auto& streamInfo = streamInfos_[activeStreamIndex_]; frameOutput.ptsSeconds = ptsToSeconds( getPtsOrDts(avFrame), formatContext_->streams[activeStreamIndex_]->time_base); frameOutput.durationSeconds = ptsToSeconds( getDuration(avFrame), formatContext_->streams[activeStreamIndex_]->time_base); - if (streamInfo.avMediaType == AVMEDIA_TYPE_AUDIO) { - convertAudioAVFrameToFrameOutputOnCPU(avFrame, frameOutput); - } else { - deviceInterface_->convertAVFrameToFrameOutput( - avFrame, frameOutput, preAllocatedOutputTensor); - } + deviceInterface_->convertAVFrameToFrameOutput( + avFrame, frameOutput, std::move(preAllocatedOutputTensor)); return frameOutput; } @@ -1432,10 +1576,9 @@ std::optional SingleStreamDecoder::maybeFlushSwrBuffers() { // Returns a [N]CHW *view* of a [N]HWC input tensor, if the options require // so. The [N] leading batch-dimension is optional i.e. the input tensor can -// be 3D or 4D. Calling permute() is guaranteed to return a view as per the -// docs: https://pytorch.org/docs/stable/generated/torch.permute.html -torch::Tensor SingleStreamDecoder::maybePermuteHWC2CHW( - torch::Tensor& hwcTensor) { +// be 3D or 4D. 
+torch::stable::Tensor SingleStreamDecoder::maybePermuteHWC2CHW( + torch::stable::Tensor& hwcTensor) { if (streamInfos_[activeStreamIndex_].videoStreamOptions.dimensionOrder == "NHWC") { return hwcTensor; @@ -1443,13 +1586,15 @@ torch::Tensor SingleStreamDecoder::maybePermuteHWC2CHW( auto numDimensions = hwcTensor.dim(); auto shape = hwcTensor.sizes(); if (numDimensions == 3) { - TORCH_CHECK(shape[2] == 3, "Not a HWC tensor: ", shape); - return hwcTensor.permute({2, 0, 1}); + STD_TORCH_CHECK( + shape[2] == 3, "Not a HWC tensor: ", intArrayRefToString(shape)); + return stablePermute(hwcTensor, {2, 0, 1}); } else if (numDimensions == 4) { - TORCH_CHECK(shape[3] == 3, "Not a NHWC tensor: ", shape); - return hwcTensor.permute({0, 3, 1, 2}); + STD_TORCH_CHECK( + shape[3] == 3, "Not a NHWC tensor: ", intArrayRefToString(shape)); + return stablePermute(hwcTensor, {0, 3, 1, 2}); } else { - TORCH_CHECK( + STD_TORCH_CHECK( false, "Expected tensor with 3 or 4 dimensions, got ", numDimensions); } } @@ -1458,7 +1603,19 @@ torch::Tensor SingleStreamDecoder::maybePermuteHWC2CHW( // PTS <-> INDEX CONVERSIONS // -------------------------------------------------------------------------- -int SingleStreamDecoder::getKeyFrameIndexForPts(int64_t pts) const { +int SingleStreamDecoder::getKeyFrameIdentifier(int64_t pts) const { + // This function "identifies" a key frame for a given pts value. + // We use the term "identifier" rather than "index" because the nature of the + // index that is returned depends on various factors: + // - If seek_mode is exact, we return the index of the key frame in the + // scanned key-frame vector (streamInfo.keyFrames). So the returned value is + // in [0, num_key_frames). + // - If seek_mode is approximate, we use av_index_search_timestamp() which + // may return a value in [0, num_key_frames) like for mkv, but also a value + // in [0, num_frames) like for mp4. It really depends on the container. 
+ // + // The range of the "identifier" doesn't matter that much, for now we only + // use it to uniquely identify a key frame in canWeAvoidSeeking(). const StreamInfo& streamInfo = streamInfos_.at(activeStreamIndex_); if (streamInfo.keyFrames.empty()) { return av_index_search_timestamp( @@ -1502,13 +1659,16 @@ int64_t SingleStreamDecoder::secondsToIndexLowerBound(double seconds) { case SeekMode::approximate: { auto& streamMetadata = containerMetadata_.allStreamMetadata[activeStreamIndex_]; - TORCH_CHECK( + STD_TORCH_CHECK( streamMetadata.averageFpsFromHeader.has_value(), "Cannot use approximate mode since we couldn't find the average fps from the metadata."); - return std::floor(seconds * streamMetadata.averageFpsFromHeader.value()); + double beginSeconds = streamMetadata.getBeginStreamSeconds(seekMode_); + double relativeSeconds = seconds - beginSeconds; + return std::floor( + relativeSeconds * streamMetadata.averageFpsFromHeader.value()); } default: - TORCH_CHECK(false, "Unknown SeekMode"); + STD_TORCH_CHECK(false, "Unknown SeekMode"); } } @@ -1530,13 +1690,16 @@ int64_t SingleStreamDecoder::secondsToIndexUpperBound(double seconds) { case SeekMode::approximate: { auto& streamMetadata = containerMetadata_.allStreamMetadata[activeStreamIndex_]; - TORCH_CHECK( + STD_TORCH_CHECK( streamMetadata.averageFpsFromHeader.has_value(), "Cannot use approximate mode since we couldn't find the average fps from the metadata."); - return std::ceil(seconds * streamMetadata.averageFpsFromHeader.value()); + double beginSeconds = streamMetadata.getBeginStreamSeconds(seekMode_); + double relativeSeconds = seconds - beginSeconds; + return std::ceil( + relativeSeconds * streamMetadata.averageFpsFromHeader.value()); } default: - TORCH_CHECK(false, "Unknown SeekMode"); + STD_TORCH_CHECK(false, "Unknown SeekMode"); } } @@ -1549,63 +1712,37 @@ int64_t SingleStreamDecoder::getPts(int64_t frameIndex) { case SeekMode::approximate: { auto& streamMetadata = 
containerMetadata_.allStreamMetadata[activeStreamIndex_]; - TORCH_CHECK( + STD_TORCH_CHECK( streamMetadata.averageFpsFromHeader.has_value(), "Cannot use approximate mode since we couldn't find the average fps from the metadata."); return secondsToClosestPts( - frameIndex / streamMetadata.averageFpsFromHeader.value(), + streamMetadata.getBeginStreamSeconds(seekMode_) + + (frameIndex / streamMetadata.averageFpsFromHeader.value()), streamInfo.timeBase); } default: - TORCH_CHECK(false, "Unknown SeekMode"); + STD_TORCH_CHECK(false, "Unknown SeekMode"); } } +FrameDims SingleStreamDecoder::getOutputDims() const { + const auto& streamMetadata = + containerMetadata_.allStreamMetadata[activeStreamIndex_]; + Rotation rotation = rotationFromDegrees(streamMetadata.rotation); + // If there is a rotation, then resizedOutputDims_ is necessarily non-null + // (the rotation transform would have set it). + if (rotation != Rotation::NONE) { + STD_TORCH_CHECK( + resizedOutputDims_.has_value(), + "Internal error: rotation is applied but resizedOutputDims_ is not set"); + } + return resizedOutputDims_.value_or(preRotationDims_); +} + // -------------------------------------------------------------------------- // STREAM AND METADATA APIS // -------------------------------------------------------------------------- -std::optional SingleStreamDecoder::getNumFrames( - const StreamMetadata& streamMetadata) { - switch (seekMode_) { - case SeekMode::custom_frame_mappings: - case SeekMode::exact: - return streamMetadata.numFramesFromContent.value(); - case SeekMode::approximate: { - return streamMetadata.numFramesFromHeader; - } - default: - TORCH_CHECK(false, "Unknown SeekMode"); - } -} - -double SingleStreamDecoder::getMinSeconds( - const StreamMetadata& streamMetadata) { - switch (seekMode_) { - case SeekMode::custom_frame_mappings: - case SeekMode::exact: - return streamMetadata.beginStreamPtsSecondsFromContent.value(); - case SeekMode::approximate: - return 0; - default: - 
TORCH_CHECK(false, "Unknown SeekMode"); - } -} - -std::optional SingleStreamDecoder::getMaxSeconds( - const StreamMetadata& streamMetadata) { - switch (seekMode_) { - case SeekMode::custom_frame_mappings: - case SeekMode::exact: - return streamMetadata.endStreamPtsSecondsFromContent.value(); - case SeekMode::approximate: { - return streamMetadata.durationSecondsFromHeader; - } - default: - TORCH_CHECK(false, "Unknown SeekMode"); - } -} - // -------------------------------------------------------------------------- // VALIDATION UTILS // -------------------------------------------------------------------------- @@ -1615,19 +1752,19 @@ void SingleStreamDecoder::validateActiveStream( auto errorMsg = "Provided stream index=" + std::to_string(activeStreamIndex_) + " was not previously added."; - TORCH_CHECK(activeStreamIndex_ != NO_ACTIVE_STREAM, errorMsg); - TORCH_CHECK(streamInfos_.count(activeStreamIndex_) > 0, errorMsg); + STD_TORCH_CHECK(activeStreamIndex_ != NO_ACTIVE_STREAM, errorMsg); + STD_TORCH_CHECK(streamInfos_.count(activeStreamIndex_) > 0, errorMsg); int allStreamMetadataSize = static_cast(containerMetadata_.allStreamMetadata.size()); - TORCH_CHECK( + STD_TORCH_CHECK( activeStreamIndex_ >= 0 && activeStreamIndex_ < allStreamMetadataSize, "Invalid stream index=" + std::to_string(activeStreamIndex_) + "; valid indices are in the range [0, " + std::to_string(allStreamMetadataSize) + ")."); if (avMediaType.has_value()) { - TORCH_CHECK( + STD_TORCH_CHECK( streamInfos_[activeStreamIndex_].avMediaType == avMediaType.value(), "The method you called isn't supported. 
", "If you're seeing this error, you are probably trying to call an ", @@ -1636,7 +1773,7 @@ void SingleStreamDecoder::validateActiveStream( } void SingleStreamDecoder::validateScannedAllStreams(const std::string& msg) { - TORCH_CHECK( + STD_TORCH_CHECK( scannedAllStreams_, "Must scan all streams to update metadata before calling ", msg); @@ -1645,24 +1782,22 @@ void SingleStreamDecoder::validateScannedAllStreams(const std::string& msg) { void SingleStreamDecoder::validateFrameIndex( const StreamMetadata& streamMetadata, int64_t frameIndex) { - if (frameIndex < 0) { - throw std::out_of_range( - "Invalid frame index=" + std::to_string(frameIndex) + - " for streamIndex=" + std::to_string(streamMetadata.streamIndex) + - "; negative indices must have an absolute value less than the number of frames, " - "and the number of frames must be known."); - } + STABLE_CHECK_INDEX( + frameIndex >= 0, + "Invalid frame index=" + std::to_string(frameIndex) + + " for streamIndex=" + std::to_string(streamMetadata.streamIndex) + + "; negative indices must have an absolute value less than the number of frames, " + "and the number of frames must be known."); // Note that if we do not have the number of frames available in our // metadata, then we assume that the frameIndex is valid. 
- std::optional numFrames = getNumFrames(streamMetadata); + std::optional numFrames = streamMetadata.getNumFrames(seekMode_); if (numFrames.has_value()) { - if (frameIndex >= numFrames.value()) { - throw std::out_of_range( - "Invalid frame index=" + std::to_string(frameIndex) + - " for streamIndex=" + std::to_string(streamMetadata.streamIndex) + - "; must be less than " + std::to_string(numFrames.value())); - } + STABLE_CHECK_INDEX( + frameIndex < numFrames.value(), + "Invalid frame index=" + std::to_string(frameIndex) + + " for streamIndex=" + std::to_string(streamMetadata.streamIndex) + + "; must be less than " + std::to_string(numFrames.value())); } } @@ -1706,7 +1841,8 @@ double SingleStreamDecoder::getPtsSecondsForFrame(int64_t frameIndex) { } std::string SingleStreamDecoder::getDeviceInterfaceDetails() const { - TORCH_CHECK(deviceInterface_ != nullptr, "Device interface doesn't exist."); + STD_TORCH_CHECK( + deviceInterface_ != nullptr, "Device interface doesn't exist."); return deviceInterface_->getDetails(); } diff --git a/src/torchcodec/_core/SingleStreamDecoder.h b/src/torchcodec/_core/SingleStreamDecoder.h index 585f42bba..98c61e775 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.h +++ b/src/torchcodec/_core/SingleStreamDecoder.h @@ -6,33 +6,32 @@ #pragma once -#include #include #include #include #include #include -#include "src/torchcodec/_core/AVIOContextHolder.h" -#include "src/torchcodec/_core/DeviceInterface.h" -#include "src/torchcodec/_core/FFMPEGCommon.h" -#include "src/torchcodec/_core/Frame.h" -#include "src/torchcodec/_core/StreamOptions.h" -#include "src/torchcodec/_core/Transform.h" +#include "AVIOContextHolder.h" +#include "DeviceInterface.h" +#include "FFMPEGCommon.h" +#include "Frame.h" +#include "Metadata.h" +#include "StableABICompat.h" +#include "StreamOptions.h" +#include "Transform.h" namespace facebook::torchcodec { // The SingleStreamDecoder class can be used to decode video frames to Tensors. 
// Note that SingleStreamDecoder is not thread-safe. // Do not call non-const APIs concurrently on the same object. -class SingleStreamDecoder { +class FORCE_PUBLIC_VISIBILITY SingleStreamDecoder { public: // -------------------------------------------------------------------------- // CONSTRUCTION API // -------------------------------------------------------------------------- - enum class SeekMode { exact, approximate, custom_frame_mappings }; - // Creates a SingleStreamDecoder from the video at videoFilePath. explicit SingleStreamDecoder( const std::string& videoFilePath, @@ -61,9 +60,15 @@ class SingleStreamDecoder { // Returns the metadata for the container. ContainerMetadata getContainerMetadata() const; + // Returns the seek mode of this decoder. + SeekMode getSeekMode() const; + + // Returns the active stream index. Returns -2 if no stream is active. + int getActiveStreamIndex() const; + // Returns the key frame indices as a tensor. The tensor is 1D and contains // int64 values, where each value is the frame index for a key frame. - torch::Tensor getKeyFrameIndices(); + torch::stable::Tensor getKeyFrameIndices(); // FrameMappings is used for the custom_frame_mappings seek mode to store // metadata of frames in a stream. The size of all tensors in this struct must @@ -74,13 +79,13 @@ class SingleStreamDecoder { // -------------------------------------------------------------------------- struct FrameMappings { // 1D tensor of int64, each value is the PTS of a frame in timebase units. - torch::Tensor all_frames; + torch::stable::Tensor all_frames; // 1D tensor of bool, each value indicates if the corresponding frame in // all_frames is a key frame. - torch::Tensor is_key_frame; + torch::stable::Tensor is_key_frame; // 1D tensor of int64, each value is the duration of the corresponding frame // in all_frames in timebase units. 
- torch::Tensor duration; + torch::stable::Tensor duration; }; void addVideoStream( @@ -109,7 +114,8 @@ class SingleStreamDecoder { // Returns frames at the given indices for a given stream as a single stacked // Tensor. - FrameBatchOutput getFramesAtIndices(const torch::Tensor& frameIndices); + FrameBatchOutput getFramesAtIndices( + const torch::stable::Tensor& frameIndices); // Returns frames within a given range. The range is defined by [start, stop). // The values retrieved from the range are: [start, start+step, @@ -124,7 +130,7 @@ class SingleStreamDecoder { // seconds=5.999, etc. FrameOutput getFramePlayedAt(double seconds); - FrameBatchOutput getFramesPlayedAt(const torch::Tensor& timestamps); + FrameBatchOutput getFramesPlayedAt(const torch::stable::Tensor& timestamps); // Returns frames within a given pts range. The range is defined by // [startSeconds, stopSeconds) with respect to the pts values for frames. The @@ -143,9 +149,13 @@ class SingleStreamDecoder { // Valid values for startSeconds and stopSeconds are: // // [beginStreamPtsSecondsFromContent, endStreamPtsSecondsFromContent) + // + // If fps is specified, frames are resampled to match the target frame + // rate by duplicating or dropping frames as necessary. FrameBatchOutput getFramesPlayedInRange( double startSeconds, - double stopSeconds); + double stopSeconds, + std::optional fps = std::nullopt); AudioFramesOutput getFramesPlayedInRangeAudio( double startSeconds, @@ -167,7 +177,8 @@ class SingleStreamDecoder { // can move it back to private. 
FrameOutput getFrameAtIndexInternal( int64_t frameIndex, - std::optional preAllocatedOutputTensor = std::nullopt); + std::optional preAllocatedOutputTensor = + std::nullopt); // Exposed for _test_frame_pts_equality, which is used to test non-regression // of pts resolution (64 to 32 bit floats) @@ -261,37 +272,21 @@ class SingleStreamDecoder { std::function filterFunction); FrameOutput getNextFrameInternal( - std::optional preAllocatedOutputTensor = std::nullopt); + std::optional preAllocatedOutputTensor = + std::nullopt); - torch::Tensor maybePermuteHWC2CHW(torch::Tensor& hwcTensor); + torch::stable::Tensor maybePermuteHWC2CHW(torch::stable::Tensor& hwcTensor); FrameOutput convertAVFrameToFrameOutput( UniqueAVFrame& avFrame, - std::optional preAllocatedOutputTensor = std::nullopt); - - void convertAVFrameToFrameOutputOnCPU( - UniqueAVFrame& avFrame, - FrameOutput& frameOutput, - std::optional preAllocatedOutputTensor = std::nullopt); - - void convertAudioAVFrameToFrameOutputOnCPU( - UniqueAVFrame& srcAVFrame, - FrameOutput& frameOutput); - - torch::Tensor convertAVFrameToTensorUsingFilterGraph( - const UniqueAVFrame& avFrame); - - int convertAVFrameToTensorUsingSwsScale( - const UniqueAVFrame& avFrame, - torch::Tensor& outputTensor); - - std::optional maybeFlushSwrBuffers(); + std::optional preAllocatedOutputTensor = + std::nullopt); // -------------------------------------------------------------------------- // PTS <-> INDEX CONVERSIONS // -------------------------------------------------------------------------- - int getKeyFrameIndexForPts(int64_t pts) const; + int getKeyFrameIdentifier(int64_t pts) const; // Returns the key frame index of the presentation timestamp using our index. // We build this index by scanning the file in @@ -306,6 +301,16 @@ class SingleStreamDecoder { int64_t getPts(int64_t frameIndex); + // Returns the output frame dimensions for video frames. 
+ // If resizedOutputDims_ is set (via resize, crop, or rotation transforms), + // returns that. Otherwise, returns preRotationDims_. + // + // Note: if resizedOutputDims_ is null, there is no rotation (the + // rotation transform would have set it), so preRotationDims_ == + // postRotationDims_. This makes it safe to use preRotationDims_ as the + // fallback. + FrameDims getOutputDims() const; + // -------------------------------------------------------------------------- // STREAM AND METADATA APIS // -------------------------------------------------------------------------- @@ -313,7 +318,7 @@ class SingleStreamDecoder { void addStream( int streamIndex, AVMediaType mediaType, - const torch::Device& device = torch::kCPU, + const StableDevice& device = StableDevice(kStableCPU), const std::string_view deviceVariant = "ffmpeg", std::optional ffmpegThreadCount = std::nullopt); @@ -326,10 +331,6 @@ class SingleStreamDecoder { // index. Note that this index may be truncated for some files. int getBestStreamIndex(AVMediaType mediaType); - std::optional getNumFrames(const StreamMetadata& streamMetadata); - double getMinSeconds(const StreamMetadata& streamMetadata); - std::optional getMaxSeconds(const StreamMetadata& streamMetadata); - // -------------------------------------------------------------------------- // VALIDATION UTILS // -------------------------------------------------------------------------- @@ -357,13 +358,18 @@ class SingleStreamDecoder { // pts to the user when they request a frame. int64_t cursor_ = INT64_MIN; bool cursorWasJustSet_ = false; - int64_t lastDecodedAvFramePts_ = 0; + // Initialized to INT64_MIN instead of 0. With 0, canWeAvoidSeeking() could + // incorrectly skip a seek when the internal FFmpeg frame index (used by + // av_index_search_timestamp() in approximate mode) had not yet been built, + // as some formats (mkv, webm) delay building it until the first seek. + // With INT64_MIN, we always seek when retrieving the first frame. 
This + // means we correctly seek when the first requested frame is far into the + // video, at the cost of an unnecessary (likely cheap) seek when the first + // requested frame is near the start. + // See: https://github.com/meta-pytorch/torchcodec/pull/1259 + int64_t lastDecodedAvFramePts_ = INT64_MIN; int64_t lastDecodedAvFrameDuration_ = 0; - - // Audio only. We cache it for performance. The video equivalents live in - // deviceInterface_. We store swrContext_ here because we only handle audio - // on the CPU. - UniqueSwrContext swrContext_; + int64_t lastDecodedFrameIndex_ = INT64_MIN; // Stores various internal decoding stats. DecodeStats decodeStats_; @@ -380,18 +386,21 @@ class SingleStreamDecoder { // resizedOutputDims_. If resizedOutputDims_ has no value, that means there // are no transforms that change the output frame dimensions. // - // The priority order for output frame dimension is: + // The priority order for output frame dimensions is: // - // 1. resizedOutputDims_; the resize requested by the user always takes - // priority. - // 2. The dimemnsions of the actual decoded AVFrame. This can change + // 1. resizedOutputDims_; the resize requested by the user (or rotation) + // always takes priority. + // 2. The dimensions of the actual decoded AVFrame. This can change // per-decoded frame, and is unknown in SingleStreamDecoder. Only the // DeviceInterface learns it immediately after decoding a raw frame but - // before the color transformation. - // 3. metdataDims_; the dimensions we learned from the metadata. + // before the color conversion. + // 3. preRotationDims_; the raw encoded dimensions from FFmpeg metadata + // (before any rotation is applied). Used as fallback for tensor + // allocation when resizedOutputDims_ is not set, which only happens + // when no rotation is needed, so preRotationDims_ is the correct value. 
std::vector> transforms_; std::optional resizedOutputDims_; - FrameDims metadataDims_; + FrameDims preRotationDims_; // Whether or not we have already scanned all streams to update the metadata. bool scannedAllStreams_ = false; diff --git a/src/torchcodec/_core/StableABICompat.h b/src/torchcodec/_core/StableABICompat.h new file mode 100644 index 000000000..d0ea62de2 --- /dev/null +++ b/src/torchcodec/_core/StableABICompat.h @@ -0,0 +1,185 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include +#include + +// Symbol visibility for the shared library +#ifdef _WIN32 +#define FORCE_PUBLIC_VISIBILITY __declspec(dllexport) +#else +#define FORCE_PUBLIC_VISIBILITY __attribute__((visibility("default"))) +#endif + +// Flag meant to be used for any API that third-party libraries may call. +// It ensures the API symbol is always public. 
+#ifdef _WIN32 +#define TORCHCODEC_THIRD_PARTY_API +#else +#define TORCHCODEC_THIRD_PARTY_API __attribute__((visibility("default"))) +#endif + +// Index error check - throws std::out_of_range which pybind11 maps to +// IndexError Use this for index validation errors that should raise IndexError +// in Python +#define STABLE_CHECK_INDEX(cond, msg) \ + do { \ + if (!(cond)) { \ + throw std::out_of_range(std::string(msg)); \ + } \ + } while (false) + +namespace facebook::torchcodec { + +// Device types +using StableDevice = torch::stable::Device; +using StableDeviceType = torch::headeronly::DeviceType; + +// DeviceGuard for CUDA context management +using StableDeviceGuard = torch::stable::accelerator::DeviceGuard; + +// Device type constants +constexpr auto kStableCPU = torch::headeronly::DeviceType::CPU; +constexpr auto kStableCUDA = torch::headeronly::DeviceType::CUDA; +constexpr auto kStableXPU = torch::headeronly::DeviceType::XPU; + +// Scalar type constants +constexpr auto kStableUInt8 = torch::headeronly::ScalarType::Byte; +constexpr auto kStableInt64 = torch::headeronly::ScalarType::Long; +constexpr auto kStableFloat32 = torch::headeronly::ScalarType::Float; +constexpr auto kStableFloat64 = torch::headeronly::ScalarType::Double; +constexpr auto kStableBool = torch::headeronly::ScalarType::Bool; + +// Note: the magic use of torch_call_dispatcher is what is officially +// recommended +// https://github.com/pytorch/pytorch/blob/89f3759429b96a8693b698f013990240bb4e25b3/docs/source/notes/libtorch_stable_abi.md?plain=1#L221 +// It allows us to make an ABI-stable call to an op that isn't officially in the +// stable ABI. Some of these ops currently include permute, rot90, etc. +// See also how xformers relies on it: +// https://github.com/facebookresearch/xformers/blob/720adff2b021f6f43957718514f5be3d10e36fb1/xformers/csrc/pt_stable_utils.h#L85 + +// TODO_STABLE_ABI: upstream? 
+inline torch::stable::Tensor stablePermute( + const torch::stable::Tensor& self, + std::vector dims) { + const auto num_args = 2; + std::array stack{ + torch::stable::detail::from(self), torch::stable::detail::from(dims)}; + TORCH_ERROR_CODE_CHECK(torch_call_dispatcher( + "aten::permute", "", stack.data(), TORCH_ABI_VERSION)); + return torch::stable::detail::to(stack[0]); +} + +// TODO_STABLE_ABI: upstream? +inline torch::stable::Tensor stableCat( + const std::vector& tensors, + int64_t dim) { + const auto num_args = 2; + std::array stack{ + torch::stable::detail::from(tensors), torch::stable::detail::from(dim)}; + TORCH_ERROR_CODE_CHECK( + torch_call_dispatcher("aten::cat", "", stack.data(), TORCH_ABI_VERSION)); + return torch::stable::detail::to(stack[0]); +} + +// TODO_STABLE_ABI: upstream? +inline torch::stable::Tensor stableRot90( + const torch::stable::Tensor& self, + int k, + int64_t dim0, + int64_t dim1) { + const auto num_args = 3; + std::array stack{ + torch::stable::detail::from(self), + torch::stable::detail::from(static_cast(k)), + torch::stable::detail::from(std::vector{dim0, dim1})}; + TORCH_ERROR_CODE_CHECK(torch_call_dispatcher( + "aten::rot90", "", stack.data(), TORCH_ABI_VERSION)); + return torch::stable::detail::to(stack[0]); +} + +// Shorthand for torch::stable::select(tensor, 0, index), i.e. tensor[index]. 
+inline torch::stable::Tensor selectRow( + const torch::stable::Tensor& tensor, + int64_t index) { + return torch::stable::select(tensor, 0, index); +} + +template +torch::headeronly::HeaderOnlyTensorAccessor mutableAccessor( + torch::stable::Tensor& tensor) { + return torch::headeronly::HeaderOnlyTensorAccessor( + tensor.mutable_data_ptr(), + tensor.sizes().data(), + tensor.strides().data()); +} + +template +torch::headeronly::HeaderOnlyTensorAccessor constAccessor( + const torch::stable::Tensor& tensor) { + return torch::headeronly::HeaderOnlyTensorAccessor( + tensor.const_data_ptr(), + tensor.sizes().data(), + tensor.strides().data()); +} + +// Copy row srcIndex from srcTensor into row dstIndex of dstTensor. +inline void copyFrame( + torch::stable::Tensor& dstTensor, + int64_t dstIndex, + const torch::stable::Tensor& srcTensor, + int64_t srcIndex) { + auto dst = selectRow(dstTensor, dstIndex); + torch::stable::copy_(dst, selectRow(srcTensor, srcIndex)); +} + +// TODO_STABLE_ABI: this should probably be natively supported by torch::stable. +// Consider upstreaming. +inline const char* deviceTypeName(StableDeviceType deviceType) { + switch (deviceType) { + case kStableCPU: + return "cpu"; + case kStableCUDA: + return "cuda"; + case kStableXPU: + return "xpu"; + default: + return "unknown"; + } +} + +// TODO_STABLE_ABI: This is needed to properly print shape info in error +// messages. There should probably be a better native way to support it, e.g. +// torch::headeronly::IntHeaderOnlyArrayRef probably needs to support the `<<` +// operator. Consider upstreaming. 
+inline std::string intArrayRefToString( + torch::headeronly::IntHeaderOnlyArrayRef arr) { + std::ostringstream ss; + ss << "["; + for (size_t i = 0; i < arr.size(); ++i) { + if (i > 0) + ss << ", "; + ss << arr[i]; + } + ss << "]"; + return ss.str(); +} + +} // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/StreamOptions.h b/src/torchcodec/_core/StreamOptions.h index e5ab256e1..6cab3c8e8 100644 --- a/src/torchcodec/_core/StreamOptions.h +++ b/src/torchcodec/_core/StreamOptions.h @@ -6,10 +6,11 @@ #pragma once -#include +#include #include #include #include +#include "StableABICompat.h" namespace facebook::torchcodec { @@ -40,14 +41,20 @@ struct VideoStreamOptions { ColorConversionLibrary::FILTERGRAPH; // By default we use CPU for decoding for both C++ and python users. - torch::Device device = torch::kCPU; + // Note: This is not used for video encoding, because device is determined by + // the device of the input frame tensor. + StableDevice device = StableDevice(kStableCPU); // Device variant (e.g., "ffmpeg", "beta", etc.) std::string_view deviceVariant = "ffmpeg"; // Encoding options - // TODO-VideoEncoder: Consider adding other optional fields here - // (bit rate, gop size, max b frames, preset) - std::optional crf; + std::optional codec; + // Optional pixel format for video encoding (e.g., "yuv420p", "yuv444p") + // If not specified, uses codec's default format. + std::optional pixelFormat; + std::optional crf; + std::optional preset; + std::optional> extraOptions; }; struct AudioStreamOptions { diff --git a/src/torchcodec/_core/SwScale.cpp b/src/torchcodec/_core/SwScale.cpp new file mode 100644 index 000000000..ca748f170 --- /dev/null +++ b/src/torchcodec/_core/SwScale.cpp @@ -0,0 +1,121 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
+ +#include "SwScale.h" +#include "Frame.h" + +namespace facebook::torchcodec { + +SwScale::SwScale(const SwsConfig& config, int swsFlags) + : config_(config), swsFlags_(swsFlags) { + needsResize_ = + (config_.inputHeight != config_.outputHeight || + config_.inputWidth != config_.outputWidth); + + // Create color conversion context (input format -> RGB24). + // Color conversion always outputs at the input resolution. + // When no resize is needed, input and output resolutions are the same. + SwsConfig colorConversionFrameConfig( + config_.inputWidth, + config_.inputHeight, + config_.inputFormat, + config_.inputColorspace, + config_.inputWidth, + config_.inputHeight); + + colorConversionSwsContext_ = createSwsContext( + colorConversionFrameConfig, + // See [Transform and Format Conversion Order] for more on the output + // pixel format. + /*outputFormat=*/AV_PIX_FMT_RGB24, + // No flags for color conversion. When resizing is needed, we use a + // separate swscale context with the appropriate resize flags. + /*swsFlags=*/0); + + // Create resize context if needed (RGB24 at input resolution -> RGB24 at + // output resolution). + if (needsResize_) { + SwsConfig resizeFrameConfig( + config_.inputWidth, + config_.inputHeight, + AV_PIX_FMT_RGB24, + AVCOL_SPC_RGB, + config_.outputWidth, + config_.outputHeight); + + resizeSwsContext_ = createSwsContext( + resizeFrameConfig, + /*outputFormat=*/AV_PIX_FMT_RGB24, + /*swsFlags=*/swsFlags_); + } +} + +int SwScale::convert( + const UniqueAVFrame& avFrame, + torch::stable::Tensor& outputTensor) { + // When resizing is needed, we do sws_scale twice: first convert to RGB24 at + // original resolution, then resize in RGB24 space. This ensures transforms + // happen in the output color space (RGB24) rather than the input color space + // (YUV). + // + // When no resize is needed, we do color conversion directly into the output + // tensor. + torch::stable::Tensor colorConvertedTensor = needsResize_ + ? 
allocateEmptyHWCTensor( + FrameDims(config_.inputHeight, config_.inputWidth), kStableCPU) + : outputTensor; + + uint8_t* colorConvertedPointers[4] = { + colorConvertedTensor.mutable_data_ptr(), + nullptr, + nullptr, + nullptr}; + int colorConvertedWidth = static_cast(colorConvertedTensor.sizes()[1]); + int colorConvertedLinesizes[4] = {colorConvertedWidth * 3, 0, 0, 0}; + + int colorConvertedHeight = sws_scale( + colorConversionSwsContext_.get(), + avFrame->data, + avFrame->linesize, + 0, + avFrame->height, + colorConvertedPointers, + colorConvertedLinesizes); + + STD_TORCH_CHECK( + colorConvertedHeight == avFrame->height, + "Color conversion swscale pass failed: colorConvertedHeight != avFrame->height: ", + colorConvertedHeight, + " != ", + avFrame->height); + + if (needsResize_) { + uint8_t* srcPointers[4] = { + colorConvertedTensor.mutable_data_ptr(), + nullptr, + nullptr, + nullptr}; + int srcLinesizes[4] = {config_.inputWidth * 3, 0, 0, 0}; + + uint8_t* dstPointers[4] = { + outputTensor.mutable_data_ptr(), nullptr, nullptr, nullptr}; + int expectedOutputWidth = static_cast(outputTensor.sizes()[1]); + int dstLinesizes[4] = {expectedOutputWidth * 3, 0, 0, 0}; + + colorConvertedHeight = sws_scale( + resizeSwsContext_.get(), + srcPointers, + srcLinesizes, + 0, + config_.inputHeight, + dstPointers, + dstLinesizes); + } + + return colorConvertedHeight; +} + +} // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/SwScale.h b/src/torchcodec/_core/SwScale.h new file mode 100644 index 000000000..8dd994365 --- /dev/null +++ b/src/torchcodec/_core/SwScale.h @@ -0,0 +1,50 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
+ +#pragma once + +#include "FFMPEGCommon.h" +#include "StableABICompat.h" + +namespace facebook::torchcodec { + +struct FrameDims; + +// SwScale uses a double swscale path: +// 1. Color conversion (e.g., YUV -> RGB24) at the original frame resolution +// 2. Resize in RGB24 space (if resizing is needed) +// +// This approach ensures that transforms happen in the output color space +// (RGB24) rather than the input color space (YUV). +// +// The caller is responsible for caching SwScale instances and recreating them +// when the context changes, similar to how FilterGraph is managed. +class SwScale { + public: + SwScale(const SwsConfig& config, int swsFlags = SWS_BILINEAR); + + int convert( + const UniqueAVFrame& avFrame, + torch::stable::Tensor& outputTensor); + + const SwsConfig& getConfig() const { + return config_; + } + + private: + SwsConfig config_; + int swsFlags_; + bool needsResize_; + + // Color conversion context (input format -> RGB24 at original resolution). + UniqueSwsContext colorConversionSwsContext_; + + // Resize context (RGB24 -> RGB24 at output resolution). + // May be null if no resize is needed. + UniqueSwsContext resizeSwsContext_; +}; + +} // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/Transform.cpp b/src/torchcodec/_core/Transform.cpp index 6083986e1..524e7a8c7 100644 --- a/src/torchcodec/_core/Transform.cpp +++ b/src/torchcodec/_core/Transform.cpp @@ -4,9 +4,9 @@ // This source code is licensed under the BSD-style license found in the // LICENSE file in the root directory of this source tree. 
-#include "src/torchcodec/_core/Transform.h" -#include -#include "src/torchcodec/_core/FFMPEGCommon.h" +#include "Transform.h" +#include "FFMPEGCommon.h" +#include "StableABICompat.h" namespace facebook::torchcodec { @@ -18,7 +18,7 @@ std::string toFilterGraphInterpolation( case ResizeTransform::InterpolationMode::BILINEAR: return "bilinear"; default: - TORCH_CHECK( + STD_TORCH_CHECK( false, "Unknown interpolation mode: " + std::to_string(static_cast(mode))); @@ -30,7 +30,7 @@ int toSwsInterpolation(ResizeTransform::InterpolationMode mode) { case ResizeTransform::InterpolationMode::BILINEAR: return SWS_BILINEAR; default: - TORCH_CHECK( + STD_TORCH_CHECK( false, "Unknown interpolation mode: " + std::to_string(static_cast(mode))); @@ -42,7 +42,7 @@ int toSwsInterpolation(ResizeTransform::InterpolationMode mode) { std::string ResizeTransform::getFilterGraphCpu() const { return "scale=" + std::to_string(outputDims_.width) + ":" + std::to_string(outputDims_.height) + - ":sws_flags=" + toFilterGraphInterpolation(interpolationMode_); + ":flags=" + toFilterGraphInterpolation(interpolationMode_); } std::optional ResizeTransform::getOutputFrameDims() const { @@ -57,31 +57,128 @@ int ResizeTransform::getSwsFlags() const { return toSwsInterpolation(interpolationMode_); } +CropTransform::CropTransform(const FrameDims& dims) : outputDims_(dims) {} + CropTransform::CropTransform(const FrameDims& dims, int x, int y) : outputDims_(dims), x_(x), y_(y) { - TORCH_CHECK(x_ >= 0, "Crop x position must be >= 0, got: ", x_); - TORCH_CHECK(y_ >= 0, "Crop y position must be >= 0, got: ", y_); + STD_TORCH_CHECK(x >= 0, "Crop x position must be >= 0, got: ", x); + STD_TORCH_CHECK(y >= 0, "Crop y position must be >= 0, got: ", y); } std::string CropTransform::getFilterGraphCpu() const { + // For the FFmpeg filter crop, if the x and y coordinates are left + // unspecified, it defaults to a center crop. + std::string coordinates = x_.has_value() + ? 
(":" + std::to_string(x_.value()) + ":" + std::to_string(y_.value())) + : ""; return "crop=" + std::to_string(outputDims_.width) + ":" + - std::to_string(outputDims_.height) + ":" + std::to_string(x_) + ":" + - std::to_string(y_) + ":exact=1"; + std::to_string(outputDims_.height) + coordinates + ":exact=1"; } std::optional CropTransform::getOutputFrameDims() const { return outputDims_; } -void CropTransform::validate(const StreamMetadata& streamMetadata) const { - TORCH_CHECK(x_ <= streamMetadata.width, "Crop x position out of bounds"); - TORCH_CHECK( - x_ + outputDims_.width <= streamMetadata.width, - "Crop x position out of bounds") - TORCH_CHECK(y_ <= streamMetadata.height, "Crop y position out of bounds"); - TORCH_CHECK( - y_ + outputDims_.height <= streamMetadata.height, - "Crop y position out of bounds"); +void CropTransform::validate(const FrameDims& inputDims) const { + STD_TORCH_CHECK( + outputDims_.height <= inputDims.height, + "Crop output height (", + outputDims_.height, + ") is greater than input height (", + inputDims.height, + ")"); + STD_TORCH_CHECK( + outputDims_.width <= inputDims.width, + "Crop output width (", + outputDims_.width, + ") is greater than input width (", + inputDims.width, + ")"); + STD_TORCH_CHECK( + x_.has_value() == y_.has_value(), + "Crop x and y values must be both set or both unset"); + if (x_.has_value()) { + STD_TORCH_CHECK( + x_.value() <= inputDims.width, + "Crop x start position, ", + x_.value(), + ", out of bounds of input width, ", + inputDims.width); + STD_TORCH_CHECK( + x_.value() + outputDims_.width <= inputDims.width, + "Crop x end position, ", + x_.value() + outputDims_.width, + ", out of bounds of input width ", + inputDims.width); + STD_TORCH_CHECK( + y_.value() <= inputDims.height, + "Crop y start position, ", + y_.value(), + ", out of bounds of input height, ", + inputDims.height); + STD_TORCH_CHECK( + y_.value() + outputDims_.height <= inputDims.height, + "Crop y end position, ", + y_.value() + 
outputDims_.height, + ", out of bounds of input height ", + inputDims.height); + } +} + +Rotation rotationFromDegrees(std::optional degrees) { + if (!degrees.has_value()) { + return Rotation::NONE; + } + // Round to nearest multiple of 90 degrees + int rounded = static_cast(std::round(*degrees / 90.0)) * 90; + switch (rounded) { + case 0: + return Rotation::NONE; + case 90: + return Rotation::CCW90; + case -90: + return Rotation::CW90; + case 180: + case -180: + return Rotation::ROTATE180; + default: + STD_TORCH_CHECK( + false, + "Unexpected rotation value: ", + *degrees, + ". Expected range is [-180, 180]."); + } +} + +RotationTransform::RotationTransform( + Rotation rotation, + const FrameDims& inputDims) + : rotation_(rotation) { + // 90° rotations swap dimensions + if (rotation_ == Rotation::CCW90 || rotation_ == Rotation::CW90) { + outputDims_ = FrameDims(inputDims.width, inputDims.height); + } else { + outputDims_ = inputDims; + } +} + +std::string RotationTransform::getFilterGraphCpu() const { + switch (rotation_) { + case Rotation::NONE: + return ""; + case Rotation::CCW90: + return "transpose=cclock"; + case Rotation::CW90: + return "transpose=clock"; + case Rotation::ROTATE180: + return "hflip,vflip"; + default: + return ""; + } +} + +std::optional RotationTransform::getOutputFrameDims() const { + return outputDims_; } } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/Transform.h b/src/torchcodec/_core/Transform.h index 28d8c28a2..49cd3213e 100644 --- a/src/torchcodec/_core/Transform.h +++ b/src/torchcodec/_core/Transform.h @@ -8,8 +8,8 @@ #include #include -#include "src/torchcodec/_core/Frame.h" -#include "src/torchcodec/_core/Metadata.h" +#include "Frame.h" +#include "Metadata.h" namespace facebook::torchcodec { @@ -29,8 +29,8 @@ class Transform { return std::nullopt; } - // The ResizeTransform is special, because it is the only transform that - // swscale can handle. 
+ // The ResizeTransform is special because it is the only transform + // that swscale can handle. virtual bool isResize() const { return false; } @@ -42,15 +42,14 @@ class Transform { // // Note that the validation function does not return anything. We expect // invalid configurations to throw an exception. - virtual void validate( - [[maybe_unused]] const StreamMetadata& streamMetadata) const {} + virtual void validate([[maybe_unused]] const FrameDims& inputDims) const {} }; class ResizeTransform : public Transform { public: enum class InterpolationMode { BILINEAR }; - ResizeTransform(const FrameDims& dims) + explicit ResizeTransform(const FrameDims& dims) : outputDims_(dims), interpolationMode_(InterpolationMode::BILINEAR) {} ResizeTransform(const FrameDims& dims, InterpolationMode interpolationMode) @@ -71,14 +70,49 @@ class CropTransform : public Transform { public: CropTransform(const FrameDims& dims, int x, int y); + // Becomes a center crop if x and y are not specified. + explicit CropTransform(const FrameDims& dims); + + std::string getFilterGraphCpu() const override; + std::optional getOutputFrameDims() const override; + void validate(const FrameDims& inputDims) const override; + + private: + FrameDims outputDims_; + std::optional x_; + std::optional y_; +}; + +// Rotation values for RotationTransform. +// These correspond to video metadata rotation angles. +enum class Rotation { + NONE, // 0° + CCW90, // 90° counter-clockwise + CW90, // 90° clockwise (or -90°) + ROTATE180 // 180° (or -180°) +}; + +// Converts rotation degrees from video metadata to Rotation enum. +// Input is expected in the range [-180, 180]. +// Rounds to nearest multiple of 90 degrees before converting. +// Returns Rotation::NONE for nullopt. +Rotation rotationFromDegrees(std::optional degrees); + +// Applies rotation in multiples of 90 degrees using FFmpeg's transpose/flip +// filters. 
Note: this does not support arbitrary angle rotation +// like TorchVision's RandomRotation transform. +// Handles rotation in the filter graph so that user transforms +// operate in post-rotation coordinate space. +class RotationTransform : public Transform { + public: + RotationTransform(Rotation rotation, const FrameDims& inputDims); + std::string getFilterGraphCpu() const override; std::optional getOutputFrameDims() const override; - void validate(const StreamMetadata& streamMetadata) const override; private: + Rotation rotation_; FrameDims outputDims_; - int x_; - int y_; }; } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/ValidationUtils.cpp b/src/torchcodec/_core/ValidationUtils.cpp index fae3dd940..b2d34a7d7 100644 --- a/src/torchcodec/_core/ValidationUtils.cpp +++ b/src/torchcodec/_core/ValidationUtils.cpp @@ -4,14 +4,14 @@ // This source code is licensed under the BSD-style license found in the // LICENSE file in the root directory of this source tree. -#include "src/torchcodec/_core/ValidationUtils.h" +#include "ValidationUtils.h" #include -#include "c10/util/Exception.h" +#include "StableABICompat.h" namespace facebook::torchcodec { int validateInt64ToInt(int64_t value, const std::string& parameterName) { - TORCH_CHECK( + STD_TORCH_CHECK( value >= std::numeric_limits::min() && value <= std::numeric_limits::max(), parameterName, @@ -32,4 +32,21 @@ std::optional validateOptionalInt64ToInt( } } +std::streampos validateUint64ToStreampos( + uint64_t value, + const std::string& parameterName) { + // We validate against streamoff limits because streampos + // (std::fpos) stores the actual position as streamoff internally. 
+ // https://en.cppreference.com/w/cpp/io/fpos.html + STD_TORCH_CHECK( + value <= + static_cast(std::numeric_limits::max()), + parameterName, + "=", + value, + " is out of range for streampos type."); + + return static_cast(value); +} + } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/ValidationUtils.h b/src/torchcodec/_core/ValidationUtils.h index ce2d11256..c6c1e670a 100644 --- a/src/torchcodec/_core/ValidationUtils.h +++ b/src/torchcodec/_core/ValidationUtils.h @@ -18,4 +18,8 @@ std::optional validateOptionalInt64ToInt( const std::optional& value, const std::string& parameterName); +std::streampos validateUint64ToStreampos( + uint64_t value, + const std::string& parameterName); + } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/WavDecoder.cpp b/src/torchcodec/_core/WavDecoder.cpp new file mode 100644 index 000000000..91b732bb3 --- /dev/null +++ b/src/torchcodec/_core/WavDecoder.cpp @@ -0,0 +1,257 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
+ +#include "WavDecoder.h" + +#include +#include +#include +#include +#include +#include "ValidationUtils.h" + +namespace facebook::torchcodec { +namespace { + +constexpr uint32_t RIFF_HEADER_SIZE = 12; // "RIFF" + fileSize + "WAVE" +constexpr uint32_t CHUNK_HEADER_SIZE = 8; // chunkID + chunkSize +// Standard WAV fmt chunk is at least 16 bytes: +// audioFormat(2) + numChannels(2) + sampleRate(4) + byteRate(4) + blockAlign(2) +// + bitsPerSample(2) +constexpr uint32_t MIN_FMT_CHUNK_SIZE = 16; +// WAVE_FORMAT_EXTENSIBLE adds to the standard WAV fmt chunk: cbSize(2) + +// wValidBitsPerSample(2) + dwChannelMask(4) + SubFormat GUID(16) = 24 more +// bytes, total 40 +constexpr uint32_t MIN_WAVEX_FMT_CHUNK_SIZE = 40; +// Arbitrary max for fmt chunk allocation - set to 5x extended format size +constexpr uint32_t MAX_FMT_CHUNK_SIZE = 200; + +// See standard format codes and Wav file format used in WavHeader: +// https://www.mmsp.ece.mcgill.ca/Documents/AudioFormats/WAVE/WAVE.html +constexpr uint16_t WAV_FORMAT_PCM = 1; +constexpr uint16_t WAV_FORMAT_EXTENSIBLE = 0xFFFE; + +bool isLittleEndian() { + uint32_t x = 1; + uint8_t firstByte; + std::memcpy(&firstByte, &x, 1); + return firstByte == 1; +} + +template +T readValue(const Container& data, size_t offset) { + static_assert(std::is_trivially_copyable_v); + static_assert( + sizeof(typename Container::value_type) == 1, + "Container value_type must be a 1-byte type for safe byte access"); + STD_TORCH_CHECK( + data.size() >= sizeof(T) && offset <= data.size() - sizeof(T), + "Reading ", + sizeof(T), + " bytes at offset ", + offset, + ": exceeds buffer length ", + data.size()); + T value; + std::memcpy(&value, data.data() + offset, sizeof(T)); + return value; +} + +bool matchesFourCC( + const uint8_t* data, + size_t dataSize, + size_t offset, + std::string_view expected) { + STD_TORCH_CHECK( + dataSize >= 4 && offset <= dataSize - 4, + "Data array too small for FourCC comparison at offset ", + offset); + return std::memcmp(data 
+ offset, expected.data(), 4) == 0; +} + +template +void safeReadFile(std::ifstream& file, Container& buffer, size_t bytesToRead) { + static_assert( + sizeof(typename Container::value_type) == 1, + "Container value_type must be a 1-byte type for safe reinterpret_cast to char*"); + STD_TORCH_CHECK( + bytesToRead <= buffer.size(), "Read size exceeds buffer length"); + file.read( + reinterpret_cast(buffer.data()), + static_cast(bytesToRead)); + STD_TORCH_CHECK( + !file.fail() && + file.gcount() == static_cast(bytesToRead), + "WAV: unexpected end of data (expected ", + bytesToRead, + " bytes, got ", + file.gcount(), + ")"); +} + +void safeSeek( + std::ifstream& file, + std::streampos pos, + std::ios_base::seekdir whence = std::ios::beg) { + file.seekg(pos, whence); + STD_TORCH_CHECK(!file.fail(), "Failed to seek to ", pos, " in WAV file"); +} + +} // namespace + +WavDecoder::WavDecoder(const std::string& path) + : file_(path, std::ios::binary) { + // TODO WavDecoder: Support big-endian host machines + STD_TORCH_CHECK( + isLittleEndian(), "WAV decoder requires little-endian architecture"); + STD_TORCH_CHECK(file_.is_open(), "Failed to open WAV file: ", path); + + uint64_t fileSize; + try { + fileSize = std::filesystem::file_size(path); + } catch (const std::filesystem::filesystem_error& e) { + STD_TORCH_CHECK( + false, "Failed to get file size for: ", path, ". 
Error: ", e.what()); + } + parseHeader(fileSize); + validateHeader(); +} + +void WavDecoder::parseHeader(uint64_t fileSize) { + safeSeek(file_, 0, std::ios::beg); + + std::array riffHeader; + safeReadFile(file_, riffHeader, RIFF_HEADER_SIZE); + + STD_TORCH_CHECK( + matchesFourCC(riffHeader.data(), riffHeader.size(), 0, "RIFF"), + "Missing RIFF header"); + STD_TORCH_CHECK( + matchesFourCC(riffHeader.data(), riffHeader.size(), 8, "WAVE"), + "Missing WAVE format identifier"); + + ChunkInfo fmtChunk = + findChunk("fmt ", static_cast(RIFF_HEADER_SIZE), fileSize); + STD_TORCH_CHECK( + fmtChunk.size >= MIN_FMT_CHUNK_SIZE, + "Invalid fmt chunk: size must be at least ", + MIN_FMT_CHUNK_SIZE, + " bytes"); + + // Use ChunkInfo to seek to and read the fmt chunk data + safeSeek( + file_, + validateUint64ToStreampos(fmtChunk.offset, "fmtChunk.offset"), + std::ios::beg); + STD_TORCH_CHECK( + fmtChunk.size <= MAX_FMT_CHUNK_SIZE, + "fmt chunk too large for allocation: ", + fmtChunk.size, + " bytes, maximum allowed is ", + MAX_FMT_CHUNK_SIZE, + " bytes"); + std::vector fmtData(fmtChunk.size); + safeReadFile(file_, fmtData, fmtChunk.size); + + header_.audioFormat = readValue(fmtData, 0); + header_.numChannels = readValue(fmtData, 2); + header_.sampleRate = readValue(fmtData, 4); + header_.bitsPerSample = readValue(fmtData, 14); + + if (header_.audioFormat == WAV_FORMAT_EXTENSIBLE) { + STD_TORCH_CHECK( + fmtChunk.size >= MIN_WAVEX_FMT_CHUNK_SIZE, + "WAVE_FORMAT_EXTENSIBLE fmt chunk too small"); + header_.subFormat = readValue(fmtData, 24); + } + + ChunkInfo dataChunk = findChunk("data", RIFF_HEADER_SIZE, fileSize); + header_.dataSize = dataChunk.size; +} + +void WavDecoder::validateHeader() { + uint16_t effectiveFormat = (header_.audioFormat == WAV_FORMAT_EXTENSIBLE) + ? header_.subFormat + : header_.audioFormat; + // TODO WavDecoder: Support WAV_FORMAT_IEEE_FLOAT 32, 64 bit + STD_TORCH_CHECK( + effectiveFormat == WAV_FORMAT_PCM, + "Unsupported WAV format: ", + effectiveFormat, + ". 
Only PCM format is supported."); + + // TODO WavDecoder: support 8, 16, 24 bits + STD_TORCH_CHECK( + effectiveFormat != WAV_FORMAT_PCM || header_.bitsPerSample == 32, + "Unsupported PCM bit depth: ", + header_.bitsPerSample, + ". Currently supported bit depths are: 32"); + + STD_TORCH_CHECK(header_.numChannels > 0, "Invalid WAV: zero channels"); + STD_TORCH_CHECK(header_.sampleRate > 0, "Invalid WAV: zero sample rate"); + + if (effectiveFormat == WAV_FORMAT_PCM && header_.bitsPerSample == 32) { + sampleFormat_ = "s32"; + codecName_ = "pcm_s32le"; + } else { + STD_TORCH_CHECK( + false, + "Unsupported format after validation. That's unexpected, please report this to the TorchCodec repo."); + } +} + +// Given a chunkId, read through each chunk until we find a match, then return +// its offset and size. +WavDecoder::ChunkInfo WavDecoder::findChunk( + std::string_view chunkId, + uint64_t startPos, + uint64_t fileSize) { + if (fileSize < CHUNK_HEADER_SIZE) { + STD_TORCH_CHECK(false, "File too small to contain chunk:", chunkId); + } + while (startPos <= fileSize - CHUNK_HEADER_SIZE) { + safeSeek( + file_, validateUint64ToStreampos(startPos, "startPos"), std::ios::beg); + + std::array chunkHeader; + safeReadFile(file_, chunkHeader, CHUNK_HEADER_SIZE); + // Read chunk size which immediately follows the chunk ID + uint32_t chunkSize = readValue(chunkHeader, 4); + + if (matchesFourCC(chunkHeader.data(), chunkHeader.size(), 0, chunkId)) { + return {startPos + CHUNK_HEADER_SIZE, chunkSize}; + } + // Skip this chunk and continue searching (odd chunks are padded) + uint64_t numBytesToSkip = + CHUNK_HEADER_SIZE + static_cast(chunkSize) + (chunkSize % 2); + STD_TORCH_CHECK( + startPos <= UINT64_MAX - numBytesToSkip, + "File position arithmetic would overflow"); + startPos += numBytesToSkip; + } + STD_TORCH_CHECK(false, "Chunk not found: ", chunkId); +} + +StreamMetadata WavDecoder::getStreamMetadata() const { + StreamMetadata metadata; + metadata.streamIndex = 0; // WAV files have 
single audio stream + metadata.sampleRate = static_cast(header_.sampleRate); + metadata.numChannels = static_cast(header_.numChannels); + metadata.sampleFormat = sampleFormat_; + metadata.codecName = codecName_; + + // Calculate duration from data size + double bitRate = static_cast(header_.sampleRate) * + static_cast(header_.numChannels) * + static_cast(header_.bitsPerSample); + metadata.bitRate = bitRate; + metadata.durationSecondsFromHeader = + static_cast(header_.dataSize) * 8 / bitRate; + metadata.beginStreamPtsSecondsFromContent = 0.0; + + return metadata; +} +} // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/WavDecoder.h b/src/torchcodec/_core/WavDecoder.h new file mode 100644 index 000000000..b940b1801 --- /dev/null +++ b/src/torchcodec/_core/WavDecoder.h @@ -0,0 +1,62 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include +#include +#include +#include +#include "Metadata.h" +#include "StableABICompat.h" + +namespace facebook::torchcodec { + +class WavDecoder { + public: + explicit WavDecoder(const std::string& path); + // Delete copy constructor and copy assignment operator since std::ifstream + // is stored as a member variable and is not copyable. 
+ WavDecoder(const WavDecoder&) = delete; + WavDecoder& operator=(const WavDecoder&) = delete; + WavDecoder(WavDecoder&&) noexcept = default; + WavDecoder& operator=(WavDecoder&&) noexcept = default; + ~WavDecoder() = default; + + StreamMetadata getStreamMetadata() const; + + private: + struct WavHeader { + uint16_t audioFormat = 0; + uint16_t numChannels = 0; + uint32_t sampleRate = 0; + uint16_t bitsPerSample = 0; + // Extended format fields (WAVE_FORMAT_EXTENSIBLE) + uint16_t subFormat = 0; // Extracted from SubFormat GUID (first 2 bytes) + uint32_t dataSize = 0; // Size of audio data in bytes + }; + + struct ChunkInfo { + uint64_t offset; + uint32_t size; + + ChunkInfo(uint64_t offset, uint32_t size) : offset(offset), size(size) {} + }; + + ChunkInfo findChunk( + std::string_view chunkId, + uint64_t startPos, + uint64_t fileSizeLimit); + void parseHeader(uint64_t actualFileSize); + void validateHeader(); + + std::ifstream file_; + WavHeader header_; + std::string sampleFormat_; + std::string codecName_; +}; + +} // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/__init__.py b/src/torchcodec/_core/__init__.py index 55ff697b3..2c97e785a 100644 --- a/src/torchcodec/_core/__init__.py +++ b/src/torchcodec/_core/__init__.py @@ -13,22 +13,20 @@ VideoStreamMetadata, ) from .ops import ( - _add_video_stream, _get_backend_details, _get_key_frame_indices, + _get_nvdec_cache_size, _test_frame_pts_equality, - add_audio_stream, - add_video_stream, - create_from_bytes, - create_from_file, - create_from_file_like, - create_from_tensor, + core_library_path, + create_streaming_encoder_to_file, + create_wav_decoder_from_file, encode_audio_to_file, encode_audio_to_file_like, encode_audio_to_tensor, encode_video_to_file, encode_video_to_file_like, encode_video_to_tensor, + ffmpeg_major_version, get_ffmpeg_library_versions, get_frame_at_index, get_frame_at_pts, @@ -39,6 +37,11 @@ get_frames_in_range, get_json_metadata, get_next_frame, - 
scan_all_streams_to_update_metadata, - seek_to_pts, + get_nvdec_cache_capacity, + get_wav_all_samples, + get_wav_metadata_from_decoder, + set_nvdec_cache_capacity, + streaming_encoder_add_frames, + streaming_encoder_add_video_stream, + streaming_encoder_close, ) diff --git a/src/torchcodec/_core/_decoder_utils.py b/src/torchcodec/_core/_decoder_utils.py new file mode 100644 index 000000000..f1e6d0eba --- /dev/null +++ b/src/torchcodec/_core/_decoder_utils.py @@ -0,0 +1,196 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import io +from collections.abc import Sequence +from pathlib import Path + +from torch import nn, Tensor +from torchcodec._core._metadata import ( + AudioStreamMetadata, + get_container_metadata, + VideoStreamMetadata, +) +from torchcodec._core.ops import ( + add_audio_stream, + add_video_stream, + create_from_bytes, + create_from_file, + create_from_file_like, + create_from_tensor, + create_wav_decoder_from_file, +) +from torchcodec.transforms import DecoderTransform +from torchcodec.transforms._decoder_transforms import _make_transform_specs + + +_ERROR_REPORTING_INSTRUCTIONS = """ +This should never happen. Please report an issue following the steps in +https://github.com/pytorch/torchcodec/issues/new?assignees=&labels=&projects=&template=bug-report.yml. 
+""" + + +def create_decoder( + *, + source: str | Path | io.RawIOBase | io.BufferedReader | bytes | Tensor, + seek_mode: str, +) -> Tensor: + if isinstance(source, str): + return create_from_file(source, seek_mode) + elif isinstance(source, Path): + return create_from_file(str(source), seek_mode) + elif isinstance(source, io.RawIOBase) or isinstance(source, io.BufferedReader): + return create_from_file_like(source, seek_mode) + elif isinstance(source, bytes): + return create_from_bytes(source, seek_mode) + elif isinstance(source, Tensor): + return create_from_tensor(source, seek_mode) + elif isinstance(source, io.TextIOBase): + raise TypeError( + "source is for reading text, likely from open(..., 'r'). Try with 'rb' for binary reading?" + ) + elif hasattr(source, "read") and hasattr(source, "seek"): + return create_from_file_like(source, seek_mode) + + raise TypeError( + f"Unknown source type: {type(source)}. " + "Supported types are str, Path, bytes, Tensor and file-like objects with " + "read(self, size: int) -> bytes and " + "seek(self, offset: int, whence: int) -> int methods." + ) + + +def create_audio_decoder( + *, + source: str | Path | io.RawIOBase | io.BufferedReader | bytes | Tensor, + seek_mode: str, + stream_index: int | None = None, + sample_rate: int | None = None, + num_channels: int | None = None, +) -> tuple[Tensor, int, AudioStreamMetadata]: + + decoder = create_decoder(source=source, seek_mode=seek_mode) + + container_metadata = get_container_metadata(decoder) + + if stream_index is None: + stream_index = container_metadata.best_audio_stream_index + if stream_index is None: + raise ValueError( + "The best audio stream is unknown and there is no specified stream. 
" + + _ERROR_REPORTING_INSTRUCTIONS + ) + + if stream_index >= len(container_metadata.streams): + raise ValueError(f"The stream at index {stream_index} is not a valid stream.") + + metadata = container_metadata.streams[stream_index] + if not isinstance(metadata, AudioStreamMetadata): + raise ValueError(f"The stream at index {stream_index} is not an audio stream.") + + add_audio_stream( + decoder, + stream_index=stream_index, + sample_rate=sample_rate, + num_channels=num_channels, + ) + + return (decoder, stream_index, metadata) + + +def create_wav_decoder(source: str | Path) -> Tensor: + if isinstance(source, str): + return create_wav_decoder_from_file(source) + elif isinstance(source, Path): + return create_wav_decoder_from_file(str(source)) + else: + raise ValueError( + "Source is not a supported uncompressed WAV file. " + "For compressed audio formats or non-WAV files, use AudioDecoder instead." + ) + + +def _get_and_validate_video_stream_metadata( + *, + decoder: Tensor, + stream_index: int | None = None, +) -> tuple[VideoStreamMetadata, int]: + container_metadata = get_container_metadata(decoder) + + if stream_index is None: + if (stream_index := container_metadata.best_video_stream_index) is None: + raise ValueError( + "The best video stream is unknown and there is no specified stream. " + + _ERROR_REPORTING_INSTRUCTIONS + ) + + if stream_index >= len(container_metadata.streams): + raise ValueError(f"The stream index {stream_index} is not a valid stream.") + + metadata = container_metadata.streams[stream_index] + if not isinstance(metadata, VideoStreamMetadata): + raise ValueError(f"The stream at index {stream_index} is not a video stream. ") + + if metadata.begin_stream_seconds is None: + raise ValueError( + "The minimum pts value in seconds is unknown. " + + _ERROR_REPORTING_INSTRUCTIONS + ) + + if metadata.end_stream_seconds is None: + raise ValueError( + "The maximum pts value in seconds is unknown. 
" + + _ERROR_REPORTING_INSTRUCTIONS + ) + + if metadata.num_frames is None: + raise ValueError( + "The number of frames is unknown. " + _ERROR_REPORTING_INSTRUCTIONS + ) + + return (metadata, stream_index) + + +def create_video_decoder( + *, + source: str | Path | io.RawIOBase | io.BufferedReader | bytes | Tensor, + seek_mode: str, + stream_index: int | None = None, + dimension_order: str = "NCHW", + num_ffmpeg_threads: int = 1, + device: str, + device_variant: str = "ffmpeg", + transforms: Sequence[DecoderTransform | nn.Module] | None = None, + custom_frame_mappings: tuple[Tensor, Tensor, Tensor] | None = None, +) -> tuple[Tensor, int, VideoStreamMetadata]: + + decoder = create_decoder(source=source, seek_mode=seek_mode) + + ( + metadata, + stream_index, + ) = _get_and_validate_video_stream_metadata( + decoder=decoder, stream_index=stream_index + ) + + transform_specs = _make_transform_specs( + transforms, + input_dims=(metadata.height, metadata.width), + ) + + add_video_stream( + decoder, + stream_index=stream_index, + dimension_order=dimension_order, + num_threads=num_ffmpeg_threads, + device=device, + device_variant=device_variant, + transform_specs=transform_specs, + custom_frame_mappings=custom_frame_mappings, + ) + + return (decoder, stream_index, metadata) diff --git a/src/torchcodec/_core/_metadata.py b/src/torchcodec/_core/_metadata.py index 1f011f516..4e31bcbd5 100644 --- a/src/torchcodec/_core/_metadata.py +++ b/src/torchcodec/_core/_metadata.py @@ -4,15 +4,14 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+ import dataclasses import json import pathlib from dataclasses import dataclass from fractions import Fraction -from typing import List, Optional, Union import torch - from torchcodec._core.ops import ( _get_container_json_metadata, _get_stream_json_metadata, @@ -25,19 +24,35 @@ @dataclass class StreamMetadata: - duration_seconds_from_header: Optional[float] + duration_seconds_from_header: float | None """Duration of the stream, in seconds, obtained from the header (float or None). This could be inaccurate.""" - begin_stream_seconds_from_header: Optional[float] + begin_stream_seconds_from_header: float | None """Beginning of the stream, in seconds, obtained from the header (float or None). Usually, this is equal to 0.""" - bit_rate: Optional[float] + bit_rate: float | None """Bit rate of the stream, in seconds (float or None).""" - codec: Optional[str] + codec: str | None """Codec (str or None).""" stream_index: int """Index of the stream that this metadata refers to (int).""" + # Computed fields (computed in C++ with fallback logic) + duration_seconds: float | None + """Duration of the stream in seconds. We try to calculate the duration + from the actual frames if a :term:`scan` was performed. Otherwise we + fall back to ``duration_seconds_from_header``. If that value is also None, + we instead calculate the duration from ``num_frames_from_header`` and + ``average_fps_from_header``. If all of those are unavailable, we fall back + to the container-level ``duration_seconds_from_header``. + """ + begin_stream_seconds: float | None + """Beginning of the stream, in seconds (float). Conceptually, this + corresponds to the first frame's :term:`pts`. If a :term:`scan` was performed + and ``begin_stream_seconds_from_content`` is not None, then it is returned. + Otherwise, this value is 0. 
+ """ + def __repr__(self): s = self.__class__.__name__ + ":\n" for field in dataclasses.fields(self): @@ -49,12 +64,12 @@ def __repr__(self): class VideoStreamMetadata(StreamMetadata): """Metadata of a single video stream.""" - begin_stream_seconds_from_content: Optional[float] + begin_stream_seconds_from_content: float | None """Beginning of the stream, in seconds (float or None). Conceptually, this corresponds to the first frame's :term:`pts`. It is only computed when a :term:`scan` is done as min(frame.pts) across all frames in the stream. Usually, this is equal to 0.""" - end_stream_seconds_from_content: Optional[float] + end_stream_seconds_from_content: float | None """End of the stream, in seconds (float or None). Conceptually, this corresponds to last_frame.pts + last_frame.duration. It is only computed when a :term:`scan` is done as max(frame.pts + @@ -65,136 +80,87 @@ class VideoStreamMetadata(StreamMetadata): simply indexing the :class:`~torchcodec.decoders.VideoDecoder` object with ``[-1]``. """ - width: Optional[int] + width: int | None """Width of the frames (int or None).""" - height: Optional[int] + height: int | None """Height of the frames (int or None).""" - num_frames_from_header: Optional[int] + num_frames_from_header: int | None """Number of frames, from the stream's metadata. This is potentially inaccurate. We recommend using the ``num_frames`` attribute instead. (int or None).""" - num_frames_from_content: Optional[int] + num_frames_from_content: int | None """Number of frames computed by TorchCodec by scanning the stream's content (the scan doesn't involve decoding). This is more accurate than ``num_frames_from_header``. We recommend using the ``num_frames`` attribute instead. (int or None).""" - average_fps_from_header: Optional[float] + average_fps_from_header: float | None """Averate fps of the stream, obtained from the header (float or None). 
We recommend using the ``average_fps`` attribute instead.""" - pixel_aspect_ratio: Optional[Fraction] + pixel_aspect_ratio: Fraction | None """Pixel Aspect Ratio (PAR), also known as Sample Aspect Ratio (SAR --- not to be confused with Storage Aspect Ratio, also SAR), is the ratio between the width and height of each pixel (``fractions.Fraction`` or None).""" - - @property - def duration_seconds(self) -> Optional[float]: - """Duration of the stream in seconds. We try to calculate the duration - from the actual frames if a :term:`scan` was performed. Otherwise we - fall back to ``duration_seconds_from_header``. If that value is also None, - we instead calculate the duration from ``num_frames_from_header`` and - ``average_fps_from_header``. - """ - if ( - self.end_stream_seconds_from_content is not None - and self.begin_stream_seconds_from_content is not None - ): - return ( - self.end_stream_seconds_from_content - - self.begin_stream_seconds_from_content - ) - elif self.duration_seconds_from_header is not None: - return self.duration_seconds_from_header - elif ( - self.num_frames_from_header is not None - and self.average_fps_from_header is not None - ): - return self.num_frames_from_header / self.average_fps_from_header - else: - return None - - @property - def begin_stream_seconds(self) -> float: - """Beginning of the stream, in seconds (float). Conceptually, this - corresponds to the first frame's :term:`pts`. If - ``begin_stream_seconds_from_content`` is not None, then it is returned. - Otherwise, this value is 0. - """ - if self.begin_stream_seconds_from_content is None: - return 0 - else: - return self.begin_stream_seconds_from_content - - @property - def end_stream_seconds(self) -> Optional[float]: - """End of the stream, in seconds (float or None). - Conceptually, this corresponds to last_frame.pts + last_frame.duration. - If ``end_stream_seconds_from_content`` is not None, then that value is - returned. Otherwise, returns ``duration_seconds``. 
- """ - if self.end_stream_seconds_from_content is None: - return self.duration_seconds - else: - return self.end_stream_seconds_from_content - - @property - def num_frames(self) -> Optional[int]: - """Number of frames in the stream (int or None). - This corresponds to ``num_frames_from_content`` if a :term:`scan` was made, - otherwise it corresponds to ``num_frames_from_header``. If that value is also - None, the number of frames is calculated from the duration and the average fps. - """ - if self.num_frames_from_content is not None: - return self.num_frames_from_content - elif self.num_frames_from_header is not None: - return self.num_frames_from_header - elif ( - self.average_fps_from_header is not None - and self.duration_seconds_from_header is not None - ): - return int(self.average_fps_from_header * self.duration_seconds_from_header) - else: - return None - - @property - def average_fps(self) -> Optional[float]: - """Average fps of the stream. If a :term:`scan` was perfomed, this is - computed from the number of frames and the duration of the stream. - Otherwise we fall back to ``average_fps_from_header``. - """ - if ( - self.end_stream_seconds_from_content is None - or self.begin_stream_seconds_from_content is None - or self.num_frames is None - # Should never happen, but prevents ZeroDivisionError: - or self.end_stream_seconds_from_content - == self.begin_stream_seconds_from_content - ): - return self.average_fps_from_header - return self.num_frames / ( - self.end_stream_seconds_from_content - - self.begin_stream_seconds_from_content - ) + rotation: float | None + """Rotation angle in degrees (counter-clockwise rounded to the nearest + multiple of 90 degrees) from the display matrix metadata. This indicates + how the video should be rotated for correct display. TorchCodec automatically + applies this rotation during decoding, so the returned frames are in the + correct orientation (float or None). + + .. 
note:: + + The :attr:`~torchcodec.decoders.VideoStreamMetadata.width` and + :attr:`~torchcodec.decoders.VideoStreamMetadata.height` attributes report + the **post-rotation** dimensions, i.e., the dimensions of frames as they + will be returned by TorchCodec's decoding methods. For videos with 90 + or -90 degree rotation, this means width and height are swapped + compared to the raw encoded dimensions in the container. + """ + color_primaries: str | None + """Color primaries as reported by FFmpeg. E.g. ``"bt709"``, ``"bt2020"``.""" + color_space: str | None + """Color space as reported by FFmpeg. E.g. ``"bt709"``, + ``"bt2020nc"``.""" + color_transfer_characteristic: str | None + """Color transfer characteristic as reported by FFmpeg + E.g. ``"bt709"``, ``"smpte2084"`` (PQ), ``"arib-std-b67"`` (HLG).""" + pixel_format: str | None + """The source pixel format of the video as reported by FFmpeg. + E.g. ``'yuv420p'``, ``'yuv444p'``, etc.""" + + # Computed fields (computed in C++ with fallback logic) + end_stream_seconds: float | None + """End of the stream, in seconds (float or None). + Conceptually, this corresponds to last_frame.pts + last_frame.duration. + If :term:`scan` was performed and``end_stream_seconds_from_content`` is not None, then that value is + returned. Otherwise, returns ``duration_seconds``. + """ + num_frames: int | None + """Number of frames in the stream (int or None). + This corresponds to ``num_frames_from_content`` if a :term:`scan` was made, + otherwise it corresponds to ``num_frames_from_header``. If that value is also + None, the number of frames is calculated from the duration and the average fps. + """ + average_fps: float | None + """Average fps of the stream. If a :term:`scan` was perfomed, this is + computed from the number of frames and the duration of the stream. + Otherwise we fall back to ``average_fps_from_header``. 
+ """ def __repr__(self): - s = super().__repr__() - s += f"{SPACES}duration_seconds: {self.duration_seconds}\n" - s += f"{SPACES}begin_stream_seconds: {self.begin_stream_seconds}\n" - s += f"{SPACES}end_stream_seconds: {self.end_stream_seconds}\n" - s += f"{SPACES}num_frames: {self.num_frames}\n" - s += f"{SPACES}average_fps: {self.average_fps}\n" - return s + return super().__repr__() @dataclass class AudioStreamMetadata(StreamMetadata): """Metadata of a single audio stream.""" - sample_rate: Optional[int] + sample_rate: int | None """The original sample rate.""" - num_channels: Optional[int] + num_channels: int | None """The number of channels (1 for mono, 2 for stereo, etc.)""" - sample_format: Optional[str] + sample_format: str | None """The original sample format, as described by FFmpeg. E.g. 'fltp', 's32', etc.""" def __repr__(self): @@ -203,19 +169,19 @@ def __repr__(self): @dataclass class ContainerMetadata: - duration_seconds_from_header: Optional[float] - bit_rate_from_header: Optional[float] - best_video_stream_index: Optional[int] - best_audio_stream_index: Optional[int] + duration_seconds_from_header: float | None + bit_rate_from_header: float | None + best_video_stream_index: int | None + best_audio_stream_index: int | None - streams: List[StreamMetadata] + streams: list[StreamMetadata] @property - def duration_seconds(self) -> Optional[float]: + def duration_seconds(self) -> float | None: raise NotImplementedError("Decide on logic and implement this!") @property - def bit_rate(self) -> Optional[float]: + def bit_rate(self) -> float | None: raise NotImplementedError("Decide on logic and implement this!") @property @@ -255,15 +221,17 @@ def get_container_metadata(decoder: torch.Tensor) -> ContainerMetadata: """ container_dict = json.loads(_get_container_json_metadata(decoder)) - streams_metadata: List[StreamMetadata] = [] + streams_metadata: list[StreamMetadata] = [] for stream_index in range(container_dict["numStreams"]): stream_dict = 
json.loads(_get_stream_json_metadata(decoder, stream_index)) common_meta = dict( duration_seconds_from_header=stream_dict.get("durationSecondsFromHeader"), + duration_seconds=stream_dict.get("durationSeconds"), bit_rate=stream_dict.get("bitRate"), begin_stream_seconds_from_header=stream_dict.get( "beginStreamSecondsFromHeader" ), + begin_stream_seconds=stream_dict.get("beginStreamSeconds"), codec=stream_dict.get("codec"), stream_index=stream_index, ) @@ -276,12 +244,22 @@ def get_container_metadata(decoder: torch.Tensor) -> ContainerMetadata: end_stream_seconds_from_content=stream_dict.get( "endStreamSecondsFromContent" ), + end_stream_seconds=stream_dict.get("endStreamSeconds"), + num_frames=stream_dict.get("numFrames"), + average_fps=stream_dict.get("averageFps"), width=stream_dict.get("width"), height=stream_dict.get("height"), num_frames_from_header=stream_dict.get("numFramesFromHeader"), num_frames_from_content=stream_dict.get("numFramesFromContent"), average_fps_from_header=stream_dict.get("averageFpsFromHeader"), pixel_aspect_ratio=_get_optional_par_fraction(stream_dict), + rotation=stream_dict.get("rotation"), + color_primaries=stream_dict.get("colorPrimaries"), + color_space=stream_dict.get("colorSpace"), + color_transfer_characteristic=stream_dict.get( + "colorTransferCharacteristic" + ), + pixel_format=stream_dict.get("pixelFormat"), **common_meta, ) ) @@ -310,7 +288,7 @@ def get_container_metadata(decoder: torch.Tensor) -> ContainerMetadata: def get_container_metadata_from_header( - filename: Union[str, pathlib.Path] + filename: str | pathlib.Path, ) -> ContainerMetadata: return get_container_metadata( create_from_file(str(filename), seek_mode="approximate") diff --git a/src/torchcodec/_core/custom_ops.cpp b/src/torchcodec/_core/custom_ops.cpp index 13ad3be35..68a16cbbd 100644 --- a/src/torchcodec/_core/custom_ops.cpp +++ b/src/torchcodec/_core/custom_ops.cpp @@ -4,9 +4,9 @@ // This source code is licensed under the BSD-style license found in the // 
LICENSE file in the root directory of this source tree. +#include #include #include -#include #include #include "c10/core/SymIntArrayRef.h" #include "c10/util/Exception.h" @@ -38,19 +38,19 @@ TORCH_LIBRARY(torchcodec_ns, m) { m.def( "_encode_audio_to_file_like(Tensor samples, int sample_rate, str format, int file_like_context, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> ()"); m.def( - "encode_video_to_file(Tensor frames, int frame_rate, str filename, int? crf=None) -> ()"); + "encode_video_to_file(Tensor frames, float frame_rate, str filename, str? codec=None, str? pixel_format=None, float? crf=None, str? preset=None, str[]? extra_options=None) -> ()"); m.def( - "encode_video_to_tensor(Tensor frames, int frame_rate, str format, int? crf=None) -> Tensor"); + "encode_video_to_tensor(Tensor frames, float frame_rate, str format, str? codec=None, str? pixel_format=None, float? crf=None, str? preset=None, str[]? extra_options=None) -> Tensor"); m.def( - "_encode_video_to_file_like(Tensor frames, int frame_rate, str format, int file_like_context, int? crf=None) -> ()"); + "_encode_video_to_file_like(Tensor frames, float frame_rate, str format, int file_like_context, str? codec=None, str? pixel_format=None, float? crf=None, str? preset=None, str[]? extra_options=None) -> ()"); m.def( "create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor"); m.def( "_create_from_file_like(int file_like_context, str? seek_mode=None) -> Tensor"); m.def( - "_add_video_stream(Tensor(a!) decoder, *, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str device=\"cpu\", str device_variant=\"ffmpeg\", str transform_specs=\"\", (Tensor, Tensor, Tensor)? custom_frame_mappings=None, str? color_conversion_library=None) -> ()"); + "_add_video_stream(Tensor(a!) decoder, *, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str device=\"cpu\", str device_variant=\"ffmpeg\", str transform_specs=\"\", Tensor? 
custom_frame_mappings_pts=None, Tensor? custom_frame_mappings_duration=None, Tensor? custom_frame_mappings_keyframe_indices=None, str? color_conversion_library=None) -> ()"); m.def( - "add_video_stream(Tensor(a!) decoder, *, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str device=\"cpu\", str device_variant=\"ffmpeg\", str transform_specs=\"\", (Tensor, Tensor, Tensor)? custom_frame_mappings=None) -> ()"); + "add_video_stream(Tensor(a!) decoder, *, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str device=\"cpu\", str device_variant=\"ffmpeg\", str transform_specs=\"\", Tensor? custom_frame_mappings_pts=None, Tensor? custom_frame_mappings_duration=None, Tensor? custom_frame_mappings_keyframe_indices=None) -> ()"); m.def( "add_audio_stream(Tensor(a!) decoder, *, int? stream_index=None, int? sample_rate=None, int? num_channels=None) -> ()"); m.def("seek_to_pts(Tensor(a!) decoder, float seconds) -> ()"); @@ -64,7 +64,7 @@ TORCH_LIBRARY(torchcodec_ns, m) { m.def( "get_frames_in_range(Tensor(a!) decoder, *, int start, int stop, int? step=None) -> (Tensor, Tensor, Tensor)"); m.def( - "get_frames_by_pts_in_range(Tensor(a!) decoder, *, float start_seconds, float stop_seconds) -> (Tensor, Tensor, Tensor)"); + "get_frames_by_pts_in_range(Tensor(a!) decoder, *, float start_seconds, float stop_seconds, float? fps=None) -> (Tensor, Tensor, Tensor)"); m.def( "get_frames_by_pts_in_range_audio(Tensor(a!) decoder, *, float start_seconds, float? stop_seconds) -> (Tensor, Tensor)"); m.def( @@ -79,25 +79,81 @@ TORCH_LIBRARY(torchcodec_ns, m) { m.def( "_test_frame_pts_equality(Tensor(a!) decoder, *, int frame_index, float pts_seconds_to_test) -> bool"); m.def("scan_all_streams_to_update_metadata(Tensor(a!) decoder) -> ()"); + m.def("create_streaming_encoder_to_file(str filename) -> Tensor"); + m.def("streaming_encoder_close(Tensor(a!) encoder) -> ()"); + m.def( + "streaming_encoder_add_video_stream(Tensor(a!) 
encoder, float frame_rate, str? codec=None, str? pixel_format=None, float? crf=None, str? preset=None, str[]? extra_options=None) -> ()"); + m.def( + "streaming_encoder_add_frames(Tensor(a!) encoder, Tensor frames) -> ()"); + m.def("set_nvdec_cache_capacity(int capacity) -> ()"); + m.def("get_nvdec_cache_capacity() -> int"); + m.def("_get_nvdec_cache_size(int device_index) -> int"); + m.def("create_wav_decoder_from_file(str filename) -> Tensor"); + m.def("get_wav_all_samples(Tensor decoder) -> Tensor"); + m.def("get_wav_metadata_from_decoder(Tensor(a!) decoder) -> str"); } namespace { -at::Tensor wrapDecoderPointerToTensor( +// TODO_STABLE_ABI: use previous deleter pattern with a lambda, once +// https://github.com/pytorch/pytorch/pull/175089 is available. +void decoderDeleter(void* data) { + delete static_cast(data); +} + +torch::stable::Tensor wrapDecoderPointerToTensor( std::unique_ptr uniqueDecoder) { SingleStreamDecoder* decoder = uniqueDecoder.release(); - auto deleter = [decoder](void*) { delete decoder; }; - at::Tensor tensor = at::from_blob( - decoder, {sizeof(SingleStreamDecoder*)}, deleter, {at::kLong}); + int64_t sizes[] = {static_cast(sizeof(SingleStreamDecoder*))}; + int64_t strides[] = {1}; + torch::stable::Tensor tensor = torch::stable::from_blob( + decoder, + {sizes, 1}, + {strides, 1}, + StableDevice(kStableCPU), + kStableInt64, + &decoderDeleter); auto videoDecoder = static_cast(tensor.mutable_data_ptr()); // TORCH_CHECK_EQ(videoDecoder, decoder) << "videoDecoder=" << videoDecoder; return tensor; } -SingleStreamDecoder* unwrapTensorToGetDecoder(at::Tensor& tensor) { - TORCH_CHECK( +// TODO_STABLE_ABI: use previous deleter pattern with a lambda, once +// https://github.com/pytorch/pytorch/pull/175089 is available. 
+void wavDecoderDeleter(void* data) { + delete static_cast(data); +} + +torch::stable::Tensor wrapWavDecoderPointerToTensor( + std::unique_ptr uniqueDecoder) { + WavDecoder* decoder = uniqueDecoder.release(); + + int64_t sizes[] = {static_cast(sizeof(WavDecoder*))}; + int64_t strides[] = {1}; + torch::stable::Tensor tensor = torch::stable::from_blob( + decoder, + {sizes, 1}, + {strides, 1}, + StableDevice(kStableCPU), + kStableInt64, + &wavDecoderDeleter); + auto wavDecoder = static_cast(tensor.mutable_data_ptr()); + STD_TORCH_CHECK(wavDecoder == decoder, "wavDecoder != decoder"); + return tensor; +} + +WavDecoder* unwrapTensorToGetWavDecoder(torch::stable::Tensor& tensor) { + STD_TORCH_CHECK( + tensor.is_contiguous(), + "fake decoder tensor must be contiguous! This is an internal error, please report on the torchcodec issue tracker."); + WavDecoder* decoder = static_cast(tensor.mutable_data_ptr()); + return decoder; +} + +SingleStreamDecoder* unwrapTensorToGetDecoder(torch::stable::Tensor& tensor) { + STD_TORCH_CHECK( tensor.is_contiguous(), "fake decoder tensor must be contiguous! This is an internal error, please report on the torchcodec issue tracker."); void* buffer = tensor.mutable_data_ptr(); @@ -105,13 +161,49 @@ SingleStreamDecoder* unwrapTensorToGetDecoder(at::Tensor& tensor) { return decoder; } +// TODO_STABLE_ABI: use previous deleter pattern with a lambda, once +// https://github.com/pytorch/pytorch/pull/175089 is available. 
+void multiStreamEncoderDeleter(void* data) { + delete static_cast(data); +} + +torch::stable::Tensor wrapMultiStreamEncoderPointerToTensor( + std::unique_ptr uniqueEncoder) { + MultiStreamEncoder* encoder = uniqueEncoder.release(); + int64_t sizes[] = {static_cast(sizeof(MultiStreamEncoder*))}; + int64_t strides[] = {1}; + torch::stable::Tensor tensor = torch::stable::from_blob( + encoder, + {sizes, 1}, + {strides, 1}, + StableDevice(kStableCPU), + kStableInt64, + &multiStreamEncoderDeleter); + auto multiStreamEncoder = + static_cast(tensor.mutable_data_ptr()); + STD_TORCH_CHECK( + multiStreamEncoder == encoder, "multiStreamEncoder != encoder"); + return tensor; +} + +MultiStreamEncoder* unwrapTensorToGetMultiStreamEncoder( + torch::stable::Tensor& tensor) { + STD_TORCH_CHECK( + tensor.is_contiguous(), + "fake encoder tensor must be contiguous! This is an internal error, please report on the torchcodec issue tracker."); + void* buffer = tensor.mutable_data_ptr(); + MultiStreamEncoder* encoder = static_cast(buffer); + return encoder; +} + // The elements of this tuple are all tensors that represent a single frame: // 1. The frame data, which is a multidimensional tensor. // 2. A single float value for the pts in seconds. // 3. A single float value for the duration in seconds. // The reason we use Tensors for the second and third values is so we can run // under torch.compile(). -using OpsFrameOutput = std::tuple; +using OpsFrameOutput = std:: + tuple; OpsFrameOutput makeOpsFrameOutput(FrameOutput& frame) { // return std::make_tuple( @@ -141,7 +233,8 @@ SingleStreamDecoder::FrameMappings makeFrameMappings( // float. // 3. Tensor of N durationis in seconds, where each duration is a // single float. 
-using OpsFrameBatchOutput = std::tuple; +using OpsFrameBatchOutput = std:: + tuple; OpsFrameBatchOutput makeOpsFrameBatchOutput(FrameBatchOutput& batch) { return std::make_tuple(batch.data, batch.ptsSeconds, batch.durationSeconds); @@ -151,7 +244,8 @@ OpsFrameBatchOutput makeOpsFrameBatchOutput(FrameBatchOutput& batch) { // of multiple audio frames: // 1. The frames data (concatenated) // 2. A single float value for the pts of the first frame, in seconds. -using OpsAudioFramesOutput = std::tuple; +using OpsAudioFramesOutput = + std::tuple; OpsAudioFramesOutput makeOpsAudioFramesOutput(AudioFramesOutput& audioFrames) { // return std::make_tuple( @@ -166,6 +260,16 @@ std::string quoteValue(const std::string& value) { return "\"" + value + "\""; } +// Helper function to unflatten extra_options, alternating keys and values +std::map unflattenExtraOptions( + const std::vector& opts) { + std::map optionsMap; + for (size_t i = 0; i < opts.size(); i += 2) { + optionsMap[opts[i]] = opts[i + 1]; + } + return optionsMap; +} + std::string mapToJson(const std::map& metadataMap) { std::stringstream ss; ss << "{\n"; @@ -184,15 +288,43 @@ std::string mapToJson(const std::map& metadataMap) { return ss.str(); } -SingleStreamDecoder::SeekMode seekModeFromString(std::string_view seekMode) { +SeekMode seekModeFromString(std::string_view seekMode) { if (seekMode == "exact") { - return SingleStreamDecoder::SeekMode::exact; + return SeekMode::exact; } else if (seekMode == "approximate") { - return SingleStreamDecoder::SeekMode::approximate; + return SeekMode::approximate; } else if (seekMode == "custom_frame_mappings") { - return SingleStreamDecoder::SeekMode::custom_frame_mappings; + return SeekMode::custom_frame_mappings; } else { - TORCH_CHECK(false, "Invalid seek mode: " + std::string(seekMode)); + STD_TORCH_CHECK(false, "Invalid seek mode: " + std::string(seekMode)); + } +} + +void writeFallbackBasedMetadata( + std::map& map, + const StreamMetadata& streamMetadata, + SeekMode 
seekMode) { + auto durationSeconds = streamMetadata.getDurationSeconds(seekMode); + if (durationSeconds.has_value()) { + map["durationSeconds"] = fmt::to_string(durationSeconds.value()); + } + + auto numFrames = streamMetadata.getNumFrames(seekMode); + if (numFrames.has_value()) { + map["numFrames"] = std::to_string(numFrames.value()); + } + + double beginStreamSeconds = streamMetadata.getBeginStreamSeconds(seekMode); + map["beginStreamSeconds"] = fmt::to_string(beginStreamSeconds); + + auto endStreamSeconds = streamMetadata.getEndStreamSeconds(seekMode); + if (endStreamSeconds.has_value()) { + map["endStreamSeconds"] = fmt::to_string(endStreamSeconds.value()); + } + + auto averageFps = streamMetadata.getAverageFps(seekMode); + if (averageFps.has_value()) { + map["averageFps"] = fmt::to_string(averageFps.value()); } } @@ -201,11 +333,24 @@ int checkedToPositiveInt(const std::string& str) { try { ret = std::stoi(str); } catch (const std::invalid_argument&) { - TORCH_CHECK(false, "String cannot be converted to an int:" + str); + STD_TORCH_CHECK(false, "String cannot be converted to an int:" + str); } catch (const std::out_of_range&) { - TORCH_CHECK(false, "String would become integer out of range:" + str); + STD_TORCH_CHECK(false, "String would become integer out of range:" + str); } - TORCH_CHECK(ret > 0, "String must be a positive integer:" + str); + STD_TORCH_CHECK(ret > 0, "String must be a positive integer:" + str); + return ret; +} + +int checkedToNonNegativeInt(const std::string& str) { + int ret = 0; + try { + ret = std::stoi(str); + } catch (const std::invalid_argument&) { + STD_TORCH_CHECK(false, "String cannot be converted to an int:" + str); + } catch (const std::out_of_range&) { + STD_TORCH_CHECK(false, "String would become integer out of range:" + str); + } + STD_TORCH_CHECK(ret >= 0, "String must be a non-negative integer:" + str); return ret; } @@ -217,7 +362,7 @@ int checkedToPositiveInt(const std::string& str) { // integers. 
Transform* makeResizeTransform( const std::vector& resizeTransformSpec) { - TORCH_CHECK( + STD_TORCH_CHECK( resizeTransformSpec.size() == 3, "resizeTransformSpec must have 3 elements including its name"); int height = checkedToPositiveInt(resizeTransformSpec[1]); @@ -235,16 +380,33 @@ Transform* makeResizeTransform( // width) for specifying image dimensions; FFmpeg uses (width, height). Transform* makeCropTransform( const std::vector& cropTransformSpec) { - TORCH_CHECK( + STD_TORCH_CHECK( cropTransformSpec.size() == 5, "cropTransformSpec must have 5 elements including its name"); int height = checkedToPositiveInt(cropTransformSpec[1]); int width = checkedToPositiveInt(cropTransformSpec[2]); - int x = checkedToPositiveInt(cropTransformSpec[3]); - int y = checkedToPositiveInt(cropTransformSpec[4]); + int x = checkedToNonNegativeInt(cropTransformSpec[3]); + int y = checkedToNonNegativeInt(cropTransformSpec[4]); return new CropTransform(FrameDims(height, width), x, y); } +// CenterCrop transform specs take the form: +// +// "center_crop, , " +// +// Where "center_crop" is the string literal and , are +// positive integers. Note that we follow the PyTorch convention of (height, +// width) for specifying image dimensions; FFmpeg uses (width, height). 
+Transform* makeCenterCropTransform( + const std::vector& cropTransformSpec) { + STD_TORCH_CHECK( + cropTransformSpec.size() == 3, + "cropTransformSpec must have 3 elements including its name"); + int height = checkedToPositiveInt(cropTransformSpec[1]); + int width = checkedToPositiveInt(cropTransformSpec[2]); + return new CropTransform(FrameDims(height, width)); +} + std::vector split(const std::string& str, char delimiter) { std::vector tokens; std::string token; @@ -265,7 +427,7 @@ std::vector makeTransforms(const std::string& transformSpecsRaw) { std::vector transformSpecs = split(transformSpecsRaw, ';'); for (const std::string& transformSpecRaw : transformSpecs) { std::vector transformSpec = split(transformSpecRaw, ','); - TORCH_CHECK( + STD_TORCH_CHECK( transformSpec.size() >= 1, "Invalid transform spec: " + transformSpecRaw); @@ -274,8 +436,10 @@ std::vector makeTransforms(const std::string& transformSpecsRaw) { transforms.push_back(makeResizeTransform(transformSpec)); } else if (name == "crop") { transforms.push_back(makeCropTransform(transformSpec)); + } else if (name == "center_crop") { + transforms.push_back(makeCenterCropTransform(transformSpec)); } else { - TORCH_CHECK(false, "Invalid transform name: " + name); + STD_TORCH_CHECK(false, "Invalid transform name: " + name); } } return transforms; @@ -288,33 +452,32 @@ std::vector makeTransforms(const std::string& transformSpecsRaw) { // ============================== // Create a SingleStreamDecoder from file and wrap the pointer in a tensor. 
-at::Tensor create_from_file( - std::string_view filename, - std::optional seek_mode = std::nullopt) { - std::string filenameStr(filename); - - SingleStreamDecoder::SeekMode realSeek = SingleStreamDecoder::SeekMode::exact; +torch::stable::Tensor create_from_file( + std::string filename, + std::optional seek_mode = std::nullopt) { + SeekMode realSeek = SeekMode::exact; if (seek_mode.has_value()) { realSeek = seekModeFromString(seek_mode.value()); } std::unique_ptr uniqueDecoder = - std::make_unique(filenameStr, realSeek); + std::make_unique(filename, realSeek); return wrapDecoderPointerToTensor(std::move(uniqueDecoder)); } // Create a SingleStreamDecoder from the actual bytes of a video and wrap the // pointer in a tensor. The SingleStreamDecoder will decode the provided bytes. -at::Tensor create_from_tensor( - at::Tensor video_tensor, - std::optional seek_mode = std::nullopt) { - TORCH_CHECK(video_tensor.is_contiguous(), "video_tensor must be contiguous"); - TORCH_CHECK( - video_tensor.scalar_type() == torch::kUInt8, +torch::stable::Tensor create_from_tensor( + const torch::stable::Tensor& video_tensor, + std::optional seek_mode = std::nullopt) { + STD_TORCH_CHECK( + video_tensor.is_contiguous(), "video_tensor must be contiguous"); + STD_TORCH_CHECK( + video_tensor.scalar_type() == kStableUInt8, "video_tensor must be kUInt8"); - SingleStreamDecoder::SeekMode realSeek = SingleStreamDecoder::SeekMode::exact; + SeekMode realSeek = SeekMode::exact; if (seek_mode.has_value()) { realSeek = seekModeFromString(seek_mode.value()); } @@ -328,16 +491,16 @@ at::Tensor create_from_tensor( return wrapDecoderPointerToTensor(std::move(uniqueDecoder)); } -at::Tensor _create_from_file_like( +torch::stable::Tensor _create_from_file_like( int64_t file_like_context, - std::optional seek_mode) { + std::optional seek_mode) { auto fileLikeContext = reinterpret_cast(file_like_context); - TORCH_CHECK( + STD_TORCH_CHECK( fileLikeContext != nullptr, "file_like_context must be a valid 
pointer"); std::unique_ptr avioContextHolder(fileLikeContext); - SingleStreamDecoder::SeekMode realSeek = SingleStreamDecoder::SeekMode::exact; + SeekMode realSeek = SeekMode::exact; if (seek_mode.has_value()) { realSeek = seekModeFromString(seek_mode.value()); } @@ -349,26 +512,32 @@ at::Tensor _create_from_file_like( } void _add_video_stream( - at::Tensor& decoder, + torch::stable::Tensor& decoder, std::optional num_threads = std::nullopt, - std::optional dimension_order = std::nullopt, + std::optional dimension_order = std::nullopt, std::optional stream_index = std::nullopt, - std::string_view device = "cpu", - std::string_view device_variant = "ffmpeg", - std::string_view transform_specs = "", - std::optional> - custom_frame_mappings = std::nullopt, - std::optional color_conversion_library = std::nullopt) { + std::string device = "cpu", + std::string device_variant = "ffmpeg", + std::string transform_specs = "", + std::optional custom_frame_mappings_pts = + std::nullopt, + std::optional custom_frame_mappings_duration = + std::nullopt, + std::optional + custom_frame_mappings_keyframe_indices = std::nullopt, + std::optional color_conversion_library = std::nullopt) { VideoStreamOptions videoStreamOptions; videoStreamOptions.ffmpegThreadCount = num_threads; if (dimension_order.has_value()) { - std::string stdDimensionOrder{dimension_order.value()}; - TORCH_CHECK(stdDimensionOrder == "NHWC" || stdDimensionOrder == "NCHW"); - videoStreamOptions.dimensionOrder = stdDimensionOrder; + STD_TORCH_CHECK( + *dimension_order == "NHWC" || *dimension_order == "NCHW", + "dimension_order must be NHWC or NCHW"); + videoStreamOptions.dimensionOrder = std::move(*dimension_order); } if (color_conversion_library.has_value()) { - std::string stdColorConversionLibrary{color_conversion_library.value()}; + const std::string& stdColorConversionLibrary = + color_conversion_library.value(); if (stdColorConversionLibrary == "filtergraph") { videoStreamOptions.colorConversionLibrary = 
ColorConversionLibrary::FILTERGRAPH; @@ -376,7 +545,7 @@ void _add_video_stream( videoStreamOptions.colorConversionLibrary = ColorConversionLibrary::SWSCALE; } else { - TORCH_CHECK( + STD_TORCH_CHECK( false, "Invalid color_conversion_library=", stdColorConversionLibrary, @@ -384,17 +553,28 @@ void _add_video_stream( } } - validateDeviceInterface(std::string(device), std::string(device_variant)); + validateDeviceInterface(device, device_variant); - videoStreamOptions.device = torch::Device(std::string(device)); - videoStreamOptions.deviceVariant = device_variant; + videoStreamOptions.device = StableDevice(std::move(device)); + videoStreamOptions.deviceVariant = std::move(device_variant); std::vector transforms = - makeTransforms(std::string(transform_specs)); - - std::optional converted_mappings = - custom_frame_mappings.has_value() - ? std::make_optional(makeFrameMappings(custom_frame_mappings.value())) + makeTransforms(std::move(transform_specs)); + + bool hasPts = custom_frame_mappings_pts.has_value(); + bool hasDuration = custom_frame_mappings_duration.has_value(); + bool hasKeyframeIndices = custom_frame_mappings_keyframe_indices.has_value(); + STD_TORCH_CHECK( + (hasPts == hasDuration) && (hasDuration == hasKeyframeIndices), + "custom_frame_mappings_pts, custom_frame_mappings_duration, and " + "custom_frame_mappings_keyframe_indices must all be provided or all be " + "None. This is a bug in TorchCodec, please report it."); + + std::optional converted_mappings = hasPts + ? std::make_optional(SingleStreamDecoder::FrameMappings{ + std::move(*custom_frame_mappings_pts), + std::move(*custom_frame_mappings_keyframe_indices), + std::move(*custom_frame_mappings_duration)}) : std::nullopt; auto videoDecoder = unwrapTensorToGetDecoder(decoder); videoDecoder->addVideoStream( @@ -406,28 +586,34 @@ void _add_video_stream( // Add a new video stream at `stream_index` using the provided options. 
void add_video_stream( - at::Tensor& decoder, + torch::stable::Tensor& decoder, std::optional num_threads = std::nullopt, - std::optional dimension_order = std::nullopt, + std::optional dimension_order = std::nullopt, std::optional stream_index = std::nullopt, - std::string_view device = "cpu", - std::string_view device_variant = "ffmpeg", - std::string_view transform_specs = "", - const std::optional>& - custom_frame_mappings = std::nullopt) { + std::string device = "cpu", + std::string device_variant = "ffmpeg", + std::string transform_specs = "", + std::optional custom_frame_mappings_pts = + std::nullopt, + std::optional custom_frame_mappings_duration = + std::nullopt, + std::optional + custom_frame_mappings_keyframe_indices = std::nullopt) { _add_video_stream( decoder, num_threads, - dimension_order, + std::move(dimension_order), stream_index, - device, - device_variant, - transform_specs, - custom_frame_mappings); + std::move(device), + std::move(device_variant), + std::move(transform_specs), + std::move(custom_frame_mappings_pts), + std::move(custom_frame_mappings_duration), + std::move(custom_frame_mappings_keyframe_indices)); } void add_audio_stream( - at::Tensor& decoder, + torch::stable::Tensor& decoder, std::optional stream_index = std::nullopt, std::optional sample_rate = std::nullopt, std::optional num_channels = std::nullopt) { @@ -440,7 +626,7 @@ void add_audio_stream( } // Seek to a particular presentation timestamp in the video in seconds. -void seek_to_pts(at::Tensor& decoder, double seconds) { +void seek_to_pts(torch::stable::Tensor& decoder, double seconds) { auto videoDecoder = static_cast(decoder.mutable_data_ptr()); videoDecoder->setCursorPtsInSeconds(seconds); @@ -448,7 +634,7 @@ void seek_to_pts(at::Tensor& decoder, double seconds) { // Get the next frame from the video as a tuple that has the frame data, pts and // duration as tensors. 
-OpsFrameOutput get_next_frame(at::Tensor& decoder) { +OpsFrameOutput get_next_frame(torch::stable::Tensor& decoder) { auto videoDecoder = unwrapTensorToGetDecoder(decoder); FrameOutput result; try { @@ -463,7 +649,9 @@ OpsFrameOutput get_next_frame(at::Tensor& decoder) { // Return the frame that is visible at a given timestamp in seconds. Each frame // in FFMPEG has a presentation timestamp and a duration. The frame visible at a // given timestamp T has T >= PTS and T < PTS + Duration. -OpsFrameOutput get_frame_at_pts(at::Tensor& decoder, double seconds) { +OpsFrameOutput get_frame_at_pts( + torch::stable::Tensor& decoder, + double seconds) { auto videoDecoder = unwrapTensorToGetDecoder(decoder); FrameOutput result; try { @@ -476,7 +664,9 @@ OpsFrameOutput get_frame_at_pts(at::Tensor& decoder, double seconds) { } // Return the frame that is visible at a given index in the video. -OpsFrameOutput get_frame_at_index(at::Tensor& decoder, int64_t frame_index) { +OpsFrameOutput get_frame_at_index( + torch::stable::Tensor& decoder, + int64_t frame_index) { auto videoDecoder = unwrapTensorToGetDecoder(decoder); auto result = videoDecoder->getFrameAtIndex(frame_index); return makeOpsFrameOutput(result); @@ -484,8 +674,8 @@ OpsFrameOutput get_frame_at_index(at::Tensor& decoder, int64_t frame_index) { // Return the frames at given indices for a given stream OpsFrameBatchOutput get_frames_at_indices( - at::Tensor& decoder, - const at::Tensor& frame_indices) { + torch::stable::Tensor& decoder, + const torch::stable::Tensor& frame_indices) { auto videoDecoder = unwrapTensorToGetDecoder(decoder); auto result = videoDecoder->getFramesAtIndices(frame_indices); return makeOpsFrameBatchOutput(result); @@ -494,7 +684,7 @@ OpsFrameBatchOutput get_frames_at_indices( // Return the frames inside a range as a single stacked Tensor. The range is // defined as [start, stop). 
OpsFrameBatchOutput get_frames_in_range( - at::Tensor& decoder, + torch::stable::Tensor& decoder, int64_t start, int64_t stop, std::optional step = std::nullopt) { @@ -505,8 +695,8 @@ OpsFrameBatchOutput get_frames_in_range( // Return the frames at given ptss for a given stream OpsFrameBatchOutput get_frames_by_pts( - at::Tensor& decoder, - const at::Tensor& timestamps) { + torch::stable::Tensor& decoder, + const torch::stable::Tensor& timestamps) { auto videoDecoder = unwrapTensorToGetDecoder(decoder); auto result = videoDecoder->getFramesPlayedAt(timestamps); return makeOpsFrameBatchOutput(result); @@ -515,18 +705,21 @@ OpsFrameBatchOutput get_frames_by_pts( // Return the frames inside the range as a single stacked Tensor. The range is // defined as [start_seconds, stop_seconds). The frames are stacked in pts // order. +// If fps is specified, frames are resampled to match the target frame +// rate by duplicating or dropping frames as necessary. OpsFrameBatchOutput get_frames_by_pts_in_range( - at::Tensor& decoder, + torch::stable::Tensor& decoder, double start_seconds, - double stop_seconds) { + double stop_seconds, + std::optional fps = std::nullopt) { auto videoDecoder = unwrapTensorToGetDecoder(decoder); auto result = - videoDecoder->getFramesPlayedInRange(start_seconds, stop_seconds); + videoDecoder->getFramesPlayedInRange(start_seconds, stop_seconds, fps); return makeOpsFrameBatchOutput(result); } OpsAudioFramesOutput get_frames_by_pts_in_range_audio( - at::Tensor& decoder, + torch::stable::Tensor& decoder, double start_seconds, std::optional stop_seconds = std::nullopt) { auto videoDecoder = unwrapTensorToGetDecoder(decoder); @@ -536,9 +729,9 @@ OpsAudioFramesOutput get_frames_by_pts_in_range_audio( } void encode_audio_to_file( - const at::Tensor& samples, + const torch::stable::Tensor& samples, int64_t sample_rate, - std::string_view file_name, + std::string file_name, std::optional bit_rate = std::nullopt, std::optional num_channels = std::nullopt, 
std::optional desired_sample_rate = std::nullopt) { @@ -556,10 +749,10 @@ void encode_audio_to_file( .encode(); } -at::Tensor encode_audio_to_tensor( - const at::Tensor& samples, +torch::stable::Tensor encode_audio_to_tensor( + const torch::stable::Tensor& samples, int64_t sample_rate, - std::string_view format, + std::string format, std::optional bit_rate = std::nullopt, std::optional num_channels = std::nullopt, std::optional desired_sample_rate = std::nullopt) { @@ -580,16 +773,16 @@ at::Tensor encode_audio_to_tensor( } void _encode_audio_to_file_like( - const at::Tensor& samples, + const torch::stable::Tensor& samples, int64_t sample_rate, - std::string_view format, + std::string format, int64_t file_like_context, std::optional bit_rate = std::nullopt, std::optional num_channels = std::nullopt, std::optional desired_sample_rate = std::nullopt) { auto fileLikeContext = reinterpret_cast(file_like_context); - TORCH_CHECK( + STD_TORCH_CHECK( fileLikeContext != nullptr, "file_like_context must be a valid pointer"); std::unique_ptr avioContextHolder(fileLikeContext); @@ -610,31 +803,52 @@ void _encode_audio_to_file_like( } void encode_video_to_file( - const at::Tensor& frames, - int64_t frame_rate, - std::string_view file_name, - std::optional crf = std::nullopt) { + const torch::stable::Tensor& frames, + double frame_rate, + std::string file_name, + std::optional codec = std::nullopt, + std::optional pixel_format = std::nullopt, + std::optional crf = std::nullopt, + std::optional preset = std::nullopt, + std::optional> extra_options = std::nullopt) { VideoStreamOptions videoStreamOptions; + videoStreamOptions.codec = std::move(codec); + videoStreamOptions.pixelFormat = std::move(pixel_format); videoStreamOptions.crf = crf; - VideoEncoder( - frames, - validateInt64ToInt(frame_rate, "frame_rate"), - file_name, - videoStreamOptions) - .encode(); + videoStreamOptions.preset = preset; + + if (extra_options.has_value()) { + videoStreamOptions.extraOptions = + 
unflattenExtraOptions(extra_options.value()); + } + + VideoEncoder(frames, frame_rate, file_name, videoStreamOptions).encode(); } -at::Tensor encode_video_to_tensor( - const at::Tensor& frames, - int64_t frame_rate, - std::string_view format, - std::optional crf = std::nullopt) { +torch::stable::Tensor encode_video_to_tensor( + const torch::stable::Tensor& frames, + double frame_rate, + std::string format, + std::optional codec = std::nullopt, + std::optional pixel_format = std::nullopt, + std::optional crf = std::nullopt, + std::optional preset = std::nullopt, + std::optional> extra_options = std::nullopt) { auto avioContextHolder = std::make_unique(); VideoStreamOptions videoStreamOptions; + videoStreamOptions.codec = std::move(codec); + videoStreamOptions.pixelFormat = std::move(pixel_format); videoStreamOptions.crf = crf; + videoStreamOptions.preset = preset; + + if (extra_options.has_value()) { + videoStreamOptions.extraOptions = + unflattenExtraOptions(extra_options.value()); + } + return VideoEncoder( frames, - validateInt64ToInt(frame_rate, "frame_rate"), + frame_rate, format, std::move(avioContextHolder), videoStreamOptions) @@ -642,23 +856,35 @@ at::Tensor encode_video_to_tensor( } void _encode_video_to_file_like( - const at::Tensor& frames, - int64_t frame_rate, - std::string_view format, + const torch::stable::Tensor& frames, + double frame_rate, + std::string format, int64_t file_like_context, - std::optional crf = std::nullopt) { + std::optional codec = std::nullopt, + std::optional pixel_format = std::nullopt, + std::optional crf = std::nullopt, + std::optional preset = std::nullopt, + std::optional> extra_options = std::nullopt) { auto fileLikeContext = reinterpret_cast(file_like_context); - TORCH_CHECK( + STD_TORCH_CHECK( fileLikeContext != nullptr, "file_like_context must be a valid pointer"); std::unique_ptr avioContextHolder(fileLikeContext); VideoStreamOptions videoStreamOptions; + videoStreamOptions.codec = std::move(codec); + 
videoStreamOptions.pixelFormat = std::move(pixel_format); videoStreamOptions.crf = crf; + videoStreamOptions.preset = preset; + + if (extra_options.has_value()) { + videoStreamOptions.extraOptions = + unflattenExtraOptions(extra_options.value()); + } VideoEncoder encoder( frames, - validateInt64ToInt(frame_rate, "frame_rate"), + frame_rate, format, std::move(avioContextHolder), videoStreamOptions); @@ -675,7 +901,7 @@ void _encode_video_to_file_like( // value when converted to seconds as a double is exactly pts_seconds_to_test. // Returns false otherwise. bool _test_frame_pts_equality( - at::Tensor& decoder, + torch::stable::Tensor& decoder, int64_t frame_index, double pts_seconds_to_test) { auto videoDecoder = unwrapTensorToGetDecoder(decoder); @@ -683,13 +909,13 @@ bool _test_frame_pts_equality( videoDecoder->getPtsSecondsForFrame(frame_index); } -torch::Tensor _get_key_frame_indices(at::Tensor& decoder) { +torch::stable::Tensor _get_key_frame_indices(torch::stable::Tensor& decoder) { auto videoDecoder = unwrapTensorToGetDecoder(decoder); return videoDecoder->getKeyFrameIndices(); } // Get the metadata from the video as a string. 
-std::string get_json_metadata(at::Tensor& decoder) { +std::string get_json_metadata(torch::stable::Tensor& decoder) { auto videoDecoder = unwrapTensorToGetDecoder(decoder); ContainerMetadata videoMetadata = videoDecoder->getContainerMetadata(); @@ -710,10 +936,10 @@ std::string get_json_metadata(at::Tensor& decoder) { videoMetadata.durationSecondsFromHeader.value_or(0); } metadataMap["durationSecondsFromHeader"] = - std::to_string(durationSecondsFromHeader); + fmt::to_string(durationSecondsFromHeader); if (videoMetadata.bitRate.has_value()) { - metadataMap["bitRate"] = std::to_string(videoMetadata.bitRate.value()); + metadataMap["bitRate"] = fmt::to_string(videoMetadata.bitRate.value()); } if (maybeBestVideoStreamIndex.has_value()) { @@ -728,24 +954,25 @@ std::string get_json_metadata(at::Tensor& decoder) { } if (streamMetadata.beginStreamPtsSecondsFromContent.has_value()) { metadataMap["beginStreamSecondsFromContent"] = - std::to_string(*streamMetadata.beginStreamPtsSecondsFromContent); + fmt::to_string(*streamMetadata.beginStreamPtsSecondsFromContent); } if (streamMetadata.endStreamPtsSecondsFromContent.has_value()) { metadataMap["endStreamSecondsFromContent"] = - std::to_string(*streamMetadata.endStreamPtsSecondsFromContent); + fmt::to_string(*streamMetadata.endStreamPtsSecondsFromContent); } if (streamMetadata.codecName.has_value()) { metadataMap["codec"] = quoteValue(streamMetadata.codecName.value()); } - if (streamMetadata.width.has_value()) { - metadataMap["width"] = std::to_string(*streamMetadata.width); + if (streamMetadata.postRotationWidth.has_value()) { + metadataMap["width"] = std::to_string(*streamMetadata.postRotationWidth); } - if (streamMetadata.height.has_value()) { - metadataMap["height"] = std::to_string(*streamMetadata.height); + if (streamMetadata.postRotationHeight.has_value()) { + metadataMap["height"] = + std::to_string(*streamMetadata.postRotationHeight); } if (streamMetadata.averageFpsFromHeader.has_value()) { 
metadataMap["averageFpsFromHeader"] = - std::to_string(*streamMetadata.averageFpsFromHeader); + fmt::to_string(*streamMetadata.averageFpsFromHeader); } } if (videoMetadata.bestVideoStreamIndex.has_value()) { @@ -761,7 +988,7 @@ std::string get_json_metadata(at::Tensor& decoder) { } // Get the container metadata as a string. -std::string get_container_json_metadata(at::Tensor& decoder) { +std::string get_container_json_metadata(torch::stable::Tensor& decoder) { auto videoDecoder = unwrapTensorToGetDecoder(decoder); auto containerMetadata = videoDecoder->getContainerMetadata(); @@ -770,11 +997,11 @@ std::string get_container_json_metadata(at::Tensor& decoder) { if (containerMetadata.durationSecondsFromHeader.has_value()) { map["durationSecondsFromHeader"] = - std::to_string(*containerMetadata.durationSecondsFromHeader); + fmt::to_string(*containerMetadata.durationSecondsFromHeader); } if (containerMetadata.bitRate.has_value()) { - map["bitRate"] = std::to_string(*containerMetadata.bitRate); + map["bitRate"] = fmt::to_string(*containerMetadata.bitRate); } if (containerMetadata.bestVideoStreamIndex.has_value()) { @@ -794,27 +1021,28 @@ std::string get_container_json_metadata(at::Tensor& decoder) { // Get the stream metadata as a string. 
std::string get_stream_json_metadata( - at::Tensor& decoder, + torch::stable::Tensor& decoder, int64_t stream_index) { auto videoDecoder = unwrapTensorToGetDecoder(decoder); auto allStreamMetadata = videoDecoder->getContainerMetadata().allStreamMetadata; - if (stream_index < 0 || - stream_index >= static_cast(allStreamMetadata.size())) { - throw std::out_of_range( - "stream_index out of bounds: " + std::to_string(stream_index)); - } + STABLE_CHECK_INDEX( + stream_index >= 0 && + stream_index < static_cast(allStreamMetadata.size()), + "stream_index out of bounds: " + std::to_string(stream_index)); auto streamMetadata = allStreamMetadata[stream_index]; + auto seekMode = videoDecoder->getSeekMode(); + int activeStreamIndex = videoDecoder->getActiveStreamIndex(); std::map map; if (streamMetadata.durationSecondsFromHeader.has_value()) { map["durationSecondsFromHeader"] = - std::to_string(*streamMetadata.durationSecondsFromHeader); + fmt::to_string(*streamMetadata.durationSecondsFromHeader); } if (streamMetadata.bitRate.has_value()) { - map["bitRate"] = std::to_string(*streamMetadata.bitRate); + map["bitRate"] = fmt::to_string(*streamMetadata.bitRate); } if (streamMetadata.numFramesFromContent.has_value()) { map["numFramesFromContent"] = @@ -826,24 +1054,24 @@ std::string get_stream_json_metadata( } if (streamMetadata.beginStreamSecondsFromHeader.has_value()) { map["beginStreamSecondsFromHeader"] = - std::to_string(*streamMetadata.beginStreamSecondsFromHeader); + fmt::to_string(*streamMetadata.beginStreamSecondsFromHeader); } if (streamMetadata.beginStreamPtsSecondsFromContent.has_value()) { map["beginStreamSecondsFromContent"] = - std::to_string(*streamMetadata.beginStreamPtsSecondsFromContent); + fmt::to_string(*streamMetadata.beginStreamPtsSecondsFromContent); } if (streamMetadata.endStreamPtsSecondsFromContent.has_value()) { map["endStreamSecondsFromContent"] = - std::to_string(*streamMetadata.endStreamPtsSecondsFromContent); + 
fmt::to_string(*streamMetadata.endStreamPtsSecondsFromContent); } if (streamMetadata.codecName.has_value()) { map["codec"] = quoteValue(streamMetadata.codecName.value()); } - if (streamMetadata.width.has_value()) { - map["width"] = std::to_string(*streamMetadata.width); + if (streamMetadata.postRotationWidth.has_value()) { + map["width"] = std::to_string(*streamMetadata.postRotationWidth); } - if (streamMetadata.height.has_value()) { - map["height"] = std::to_string(*streamMetadata.height); + if (streamMetadata.postRotationHeight.has_value()) { + map["height"] = std::to_string(*streamMetadata.postRotationHeight); } if (streamMetadata.sampleAspectRatio.has_value()) { map["sampleAspectRatioNum"] = @@ -851,9 +1079,24 @@ std::string get_stream_json_metadata( map["sampleAspectRatioDen"] = std::to_string((*streamMetadata.sampleAspectRatio).den); } + if (streamMetadata.rotation.has_value()) { + map["rotation"] = std::to_string(*streamMetadata.rotation); + } + if (auto name = streamMetadata.getColorPrimariesName()) { + map["colorPrimaries"] = quoteValue(*name); + } + if (auto name = streamMetadata.getColorSpaceName()) { + map["colorSpace"] = quoteValue(*name); + } + if (auto name = streamMetadata.getColorTransferCharacteristicName()) { + map["colorTransferCharacteristic"] = quoteValue(*name); + } + if (streamMetadata.pixelFormat.has_value()) { + map["pixelFormat"] = quoteValue(streamMetadata.pixelFormat.value()); + } if (streamMetadata.averageFpsFromHeader.has_value()) { map["averageFpsFromHeader"] = - std::to_string(*streamMetadata.averageFpsFromHeader); + fmt::to_string(*streamMetadata.averageFpsFromHeader); } if (streamMetadata.sampleRate.has_value()) { map["sampleRate"] = std::to_string(*streamMetadata.sampleRate); @@ -871,6 +1114,36 @@ std::string get_stream_json_metadata( } else { map["mediaType"] = quoteValue("other"); } + + // Check whether content-based metadata is available for this stream. + // In exact mode: content-based metadata exists for all streams. 
+ // In approximate mode: content-based metadata does not exist for any stream. + // In custom_frame_mappings: content-based metadata exists only for the active + // stream. + // + // Our fallback logic assumes content-based metadata is available. + // It is available for decoding on the active stream, but would break + // when getting metadata from non-active streams. + if ((seekMode != SeekMode::custom_frame_mappings) || + (seekMode == SeekMode::custom_frame_mappings && + stream_index == activeStreamIndex)) { + writeFallbackBasedMetadata(map, streamMetadata, seekMode); + } else if (seekMode == SeekMode::custom_frame_mappings) { + // If this is not the active stream, then we don't have content-based + // metadata for custom frame mappings. In that case, we want the same + // behavior as we would get with approximate mode. Encoding this behavior in + // the fallback logic itself is tricky and not worth it for this corner + // case. So we hardcode in approximate mode. + // + // TODO: This hacky behavior is only necessary because the custom frame + // mapping is supplied in SingleStreamDecoder::addVideoStream() rather + // than in the constructor. And it's supplied to addVideoStream() and + // not the constructor because we need to know the stream index. If we + // can encode the relevant stream indices into custom frame mappings + // itself, then we can put it in the constructor. + writeFallbackBasedMetadata(map, streamMetadata, SeekMode::approximate); + } + return mapToJson(map); } @@ -906,7 +1179,7 @@ std::string _get_json_ffmpeg_library_versions() { return ss.str(); } -std::string get_backend_details(at::Tensor& decoder) { +std::string get_backend_details(torch::stable::Tensor& decoder) { auto videoDecoder = unwrapTensorToGetDecoder(decoder); return videoDecoder->getDeviceInterfaceDetails(); } @@ -915,7 +1188,7 @@ std::string get_backend_details(at::Tensor& decoder) { // keyframe positions, etc. 
Exact keyframe positions are useful for efficient // accurate seeking. Note that this function reads the entire video but it does // not decode frames. Reading a video file is much cheaper than decoding it. -void scan_all_streams_to_update_metadata(at::Tensor& decoder) { +void scan_all_streams_to_update_metadata(torch::stable::Tensor& decoder) { auto videoDecoder = unwrapTensorToGetDecoder(decoder); videoDecoder->scanFileAndUpdateMetadataAndIndex(); } @@ -954,9 +1227,18 @@ TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) { m.impl("_test_frame_pts_equality", &_test_frame_pts_equality); m.impl( "scan_all_streams_to_update_metadata", - &scan_all_streams_to_update_metadata); + TORCH_BOX(&scan_all_streams_to_update_metadata)); - m.impl("_get_backend_details", &get_backend_details); + m.impl("_get_backend_details", TORCH_BOX(&get_backend_details)); + m.impl( + "create_streaming_encoder_to_file", + TORCH_BOX(&create_streaming_encoder_to_file)); + m.impl("streaming_encoder_close", TORCH_BOX(&streaming_encoder_close)); + m.impl( + "streaming_encoder_add_video_stream", + TORCH_BOX(&streaming_encoder_add_video_stream)); + m.impl( + "streaming_encoder_add_frames", TORCH_BOX(&streaming_encoder_add_frames)); } } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/fetch_and_expose_non_gpl_ffmpeg_libs.cmake b/src/torchcodec/_core/fetch_and_expose_non_gpl_ffmpeg_libs.cmake index 07abd2e87..293b217c4 100644 --- a/src/torchcodec/_core/fetch_and_expose_non_gpl_ffmpeg_libs.cmake +++ b/src/torchcodec/_core/fetch_and_expose_non_gpl_ffmpeg_libs.cmake @@ -9,45 +9,64 @@ endif() include(FetchContent) -if (UNIX AND NOT APPLE) - set(LINUX TRUE) -else() - set(LINUX FALSE) -endif() - set( base_url https://pytorch.s3.amazonaws.com/torchcodec/ffmpeg/2025-03-14 ) if (LINUX) - set(lib_dir "lib") - - set( - platform_url - ${base_url}/linux_x86_64 - ) - - set( - f4_sha256 - 1a083f1922443bedb5243d04896383b8c606778a7ddb9d886c8303e55339fe0c - ) - set( - f5_sha256 - 
65d6ad54082d94dcb3f801d73df2265e0e1bb303c7afbce7723e3b77ccd0e207 - ) - set( - f6_sha256 - 8bd5939c2f4a4b072e837e7870c13fe7d13824e5ff087ab534e4db4e90b7be9c - ) - set( - f7_sha256 - 1cb946d8b7c6393c2c3ebe1f900b8de7a2885fe614c45d4ec32c9833084f2f26 - ) - set( - f8_sha256 - c55b3c1a4b5e4d5fdd7c632bea3ab6f45b4e37cc8e0999dda3f84a8ed8defad8 - ) + if (CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64|arm64|ARM64") + set( + platform_url + ${base_url}/linux_aarch64 + ) + set( + f4_sha256 + a310a2ed9ffe555fd3278dae15065541098dd35e124564671dcda6a6620ac842 + ) + set( + f5_sha256 + 89ca7996bccbc2db49adaa401d20fdbabffe0e1b4e07a0f81d6b143e858b7c8d + ) + set( + f6_sha256 + ae44c67b4587d061b8e9cc8990ca891ee013fe52ad79e5016ba29871562621da + ) + set( + f7_sha256 + 948e2cac66ca6f68ff526d5e84138e94bce0f1a7c83f502d15d85d0bd3ddc112 + ) + set( + f8_sha256 + b9cfd99ae75a14e58300854967d4dc49de0b3daa551df51ea1f52a3f08d2c8af + ) + else() + set( + platform_url + ${base_url}/linux_x86_64 + ) + + set( + f4_sha256 + 1a083f1922443bedb5243d04896383b8c606778a7ddb9d886c8303e55339fe0c + ) + set( + f5_sha256 + 65d6ad54082d94dcb3f801d73df2265e0e1bb303c7afbce7723e3b77ccd0e207 + ) + set( + f6_sha256 + 8bd5939c2f4a4b072e837e7870c13fe7d13824e5ff087ab534e4db4e90b7be9c + ) + set( + f7_sha256 + 1cb946d8b7c6393c2c3ebe1f900b8de7a2885fe614c45d4ec32c9833084f2f26 + ) + set( + f8_sha256 + c55b3c1a4b5e4d5fdd7c632bea3ab6f45b4e37cc8e0999dda3f84a8ed8defad8 + ) + endif() set( f4_library_file_names libavutil.so.56 @@ -99,7 +118,6 @@ if (LINUX) libswresample.so.6 ) elseif (APPLE) - set(lib_dir "lib") set( platform_url ${base_url}/macos_arm64 @@ -124,60 +142,7 @@ elseif (APPLE) f8_sha256 beb936b76f25d2621228a12cdb67c9ae3d1eff7aa713ef8d1167ebf0c25bd5ec ) - - set( - f4_library_file_names - libavutil.56.dylib - libavcodec.58.dylib - libavformat.58.dylib - libavdevice.58.dylib - libavfilter.7.dylib - libswscale.5.dylib - libswresample.3.dylib - ) - set( - f5_library_file_names - libavutil.57.dylib - libavcodec.59.dylib - 
libavformat.59.dylib - libavdevice.59.dylib - libavfilter.8.dylib - libswscale.6.dylib - libswresample.4.dylib - ) - set( - f6_library_file_names - libavutil.58.dylib - libavcodec.60.dylib - libavformat.60.dylib - libavdevice.60.dylib - libavfilter.9.dylib - libswscale.7.dylib - libswresample.4.dylib - ) - set( - f7_library_file_names - libavutil.59.dylib - libavcodec.61.dylib - libavformat.61.dylib - libavdevice.61.dylib - libavfilter.10.dylib - libswscale.8.dylib - libswresample.5.dylib - ) - set( - f8_library_file_names - libavutil.60.dylib - libavcodec.62.dylib - libavformat.62.dylib - libavdevice.62.dylib - libavfilter.11.dylib - libswscale.9.dylib - libswresample.6.dylib - ) - elseif (WIN32) - set(lib_dir "bin") set( platform_url ${base_url}/windows_x86_64 @@ -202,57 +167,6 @@ elseif (WIN32) f8_sha256 bac845ac79876b104959cb0e7b9dec772a261116344dd17d2f97e7ddfac4a73f ) - - set( - f4_library_file_names - avutil.lib - avcodec.lib - avformat.lib - avdevice.lib - avfilter.lib - swscale.lib - swresample.lib - ) - set( - f5_library_file_names - avutil.lib - avcodec.lib - avformat.lib - avdevice.lib - avfilter.lib - swscale.lib - swresample.lib - ) - set( - f6_library_file_names - avutil.lib - avcodec.lib - avformat.lib - avdevice.lib - avfilter.lib - swscale.lib - swresample.lib - ) - set( - f7_library_file_names - avutil.lib - avcodec.lib - avformat.lib - avdevice.lib - avfilter.lib - swscale.lib - swresample.lib - ) - set( - f8_library_file_names - avutil.lib - avcodec.lib - avformat.lib - avdevice.lib - avfilter.lib - swscale.lib - swresample.lib - ) else() message( FATAL_ERROR @@ -293,68 +207,12 @@ FetchContent_Declare( FetchContent_MakeAvailable(f4 f5 f6 f7 f8) -add_library(ffmpeg4 INTERFACE) -add_library(ffmpeg5 INTERFACE) -add_library(ffmpeg6 INTERFACE) -add_library(ffmpeg7 INTERFACE) -add_library(ffmpeg8 INTERFACE) +# makes add_ffmpeg_target available +include("${CMAKE_CURRENT_SOURCE_DIR}/../share/cmake/TorchCodec/ffmpeg_versions.cmake") # Note: the 
f?_SOURCE_DIR variables were set by FetchContent_MakeAvailable -target_include_directories(ffmpeg4 INTERFACE ${f4_SOURCE_DIR}/include) -target_include_directories(ffmpeg5 INTERFACE ${f5_SOURCE_DIR}/include) -target_include_directories(ffmpeg6 INTERFACE ${f6_SOURCE_DIR}/include) -target_include_directories(ffmpeg7 INTERFACE ${f7_SOURCE_DIR}/include) -target_include_directories(ffmpeg8 INTERFACE ${f8_SOURCE_DIR}/include) - - -list( - TRANSFORM f4_library_file_names - PREPEND ${f4_SOURCE_DIR}/${lib_dir}/ - OUTPUT_VARIABLE f4_library_paths -) -list( - TRANSFORM f5_library_file_names - PREPEND ${f5_SOURCE_DIR}/${lib_dir}/ - OUTPUT_VARIABLE f5_library_paths -) -list( - TRANSFORM f6_library_file_names - PREPEND ${f6_SOURCE_DIR}/${lib_dir}/ - OUTPUT_VARIABLE f6_library_paths -) -list( - TRANSFORM f7_library_file_names - PREPEND ${f7_SOURCE_DIR}/${lib_dir}/ - OUTPUT_VARIABLE f7_library_paths -) -list( - TRANSFORM f8_library_file_names - PREPEND ${f8_SOURCE_DIR}/${lib_dir}/ - OUTPUT_VARIABLE f8_library_paths -) - -target_link_libraries( - ffmpeg4 - INTERFACE - ${f4_library_paths} -) -target_link_libraries( - ffmpeg5 - INTERFACE - ${f5_library_paths} -) -target_link_libraries( - ffmpeg6 - INTERFACE - ${f6_library_paths} -) -target_link_libraries( - ffmpeg7 - INTERFACE - ${f7_library_paths} -) -target_link_libraries( - ffmpeg8 - INTERFACE - ${f8_library_paths} -) +add_ffmpeg_target(4 "${f4_SOURCE_DIR}") +add_ffmpeg_target(5 "${f5_SOURCE_DIR}") +add_ffmpeg_target(6 "${f6_SOURCE_DIR}") +add_ffmpeg_target(7 "${f7_SOURCE_DIR}") +add_ffmpeg_target(8 "${f8_SOURCE_DIR}") diff --git a/src/torchcodec/_core/ops.py b/src/torchcodec/_core/ops.py index 9ea928890..353c0d24d 100644 --- a/src/torchcodec/_core/ops.py +++ b/src/torchcodec/_core/ops.py @@ -4,85 +4,41 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+ import io import json +import os +import shutil +import sys import warnings -from types import ModuleType -from typing import List, Optional, Tuple, Union +from contextlib import nullcontext +from pathlib import Path import torch # from torch.library import get_ctx, register_fake from torch.library import register_fake from torchcodec._internally_replaced_utils import ( # @manual=//pytorch/torchcodec/src:internally_replaced_utils - _get_extension_path, - _get_pybind_ops_module_name, - _load_pybind11_module, + load_torchcodec_shared_libraries, ) -_pybind_ops: Optional[ModuleType] = None - - -def load_torchcodec_shared_libraries(): - # Successively try to load the shared libraries for each version of FFmpeg - # that we support. We always start with the highest version, working our way - # down to the lowest version. Once we can load ALL shared libraries for a - # version of FFmpeg, we have succeeded and we stop. - # - # Note that we use two different methods for loading shared libraries: - # - # 1. torch.ops.load_library(): For PyTorch custom ops and the C++ only - # libraries the custom ops depend on. Loading libraries through PyTorch - # registers the custom ops with PyTorch's runtime and the ops can be - # accessed through torch.ops after loading. - # - # 2. importlib: For pybind11 modules. We load them dynamically, rather - # than using a plain import statement. A plain import statement only - # works when the module name and file name match exactly. Our shared - # libraries do not meet those conditions. 
- - exceptions = [] - for ffmpeg_major_version in (8, 7, 6, 5, 4): - pybind_ops_module_name = _get_pybind_ops_module_name(ffmpeg_major_version) - decoder_library_name = f"libtorchcodec_core{ffmpeg_major_version}" - custom_ops_library_name = f"libtorchcodec_custom_ops{ffmpeg_major_version}" - pybind_ops_library_name = f"libtorchcodec_pybind_ops{ffmpeg_major_version}" - try: - torch.ops.load_library(_get_extension_path(decoder_library_name)) - torch.ops.load_library(_get_extension_path(custom_ops_library_name)) - - pybind_ops_library_path = _get_extension_path(pybind_ops_library_name) - global _pybind_ops - _pybind_ops = _load_pybind11_module( - pybind_ops_module_name, pybind_ops_library_path - ) - return - except Exception as e: - # TODO: recording and reporting exceptions this way is OK for now as it's just for debugging, - # but we should probably handle that via a proper logging mechanism. - exceptions.append((ffmpeg_major_version, e)) - - traceback = ( - "\n[start of libtorchcodec loading traceback]\n" - + "\n".join(f"FFmpeg version {v}: {str(e)}" for v, e in exceptions) - + "\n[end of libtorchcodec loading traceback]." - ) - raise RuntimeError( - f"""Could not load libtorchcodec. Likely causes: - 1. FFmpeg is not properly installed in your environment. We support - versions 4, 5, 6, and 7 on all platforms, and 8 on Mac and Linux. - 2. The PyTorch version ({torch.__version__}) is not compatible with - this version of TorchCodec. Refer to the version compatibility - table: - https://github.com/pytorch/torchcodec?tab=readme-ov-file#installing-torchcodec. - 3. Another runtime dependency; see exceptions below. - The following exceptions were raised as we tried to load libtorchcodec: - """ - f"{traceback}" - ) +expose_ffmpeg_dlls = nullcontext +if sys.platform == "win32" and hasattr(os, "add_dll_directory"): + # On windows we try to locate the FFmpeg DLLs and temporarily add them to + # the DLL search path. 
This seems to be needed on some users' machines, but + not on our CI. We don't know why. + if ffmpeg_path := shutil.which("ffmpeg"): + + def expose_ffmpeg_dlls(): # noqa: F811 + ffmpeg_dir = Path(ffmpeg_path).parent.absolute() + return os.add_dll_directory(str(ffmpeg_dir)) # that's the actual CM (context manager) -load_torchcodec_shared_libraries() + +with expose_ffmpeg_dlls(): + ffmpeg_major_version, core_library_path, _pybind_ops = ( + load_torchcodec_shared_libraries() + ) import types @@ -124,8 +80,46 @@ def disallow_in_graph(self, fn): _create_from_file_like = torch._dynamo.disallow_in_graph( torch.ops.torchcodec_ns._create_from_file_like.default ) -add_video_stream = torch.ops.torchcodec_ns.add_video_stream.default +_add_video_stream_raw = torch.ops.torchcodec_ns.add_video_stream.default +_add_video_stream = torch.ops.torchcodec_ns._add_video_stream.default + + +def add_video_stream( + decoder: torch.Tensor, + *, + num_threads: int | None = None, + dimension_order: str | None = None, + stream_index: int | None = None, + device: str = "cpu", + device_variant: str = "ffmpeg", + transform_specs: str = "", + custom_frame_mappings: ( + tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None + ) = None, +) -> None: + custom_frame_mappings_pts: torch.Tensor | None = None + custom_frame_mappings_keyframe_indices: torch.Tensor | None = None + custom_frame_mappings_duration: torch.Tensor | None = None + if custom_frame_mappings is not None: + ( + custom_frame_mappings_pts, + custom_frame_mappings_keyframe_indices, + custom_frame_mappings_duration, + ) = custom_frame_mappings + _add_video_stream_raw( + decoder, + num_threads=num_threads, + dimension_order=dimension_order, + stream_index=stream_index, + device=device, + device_variant=device_variant, + transform_specs=transform_specs, + custom_frame_mappings_pts=custom_frame_mappings_pts, + custom_frame_mappings_duration=custom_frame_mappings_duration, + custom_frame_mappings_keyframe_indices=custom_frame_mappings_keyframe_indices, + ) + + 
add_audio_stream = torch.ops.torchcodec_ns.add_audio_stream.default seek_to_pts = torch.ops.torchcodec_ns.seek_to_pts.default get_next_frame = torch.ops.torchcodec_ns.get_next_frame.default @@ -154,14 +148,32 @@ def disallow_in_graph(self, fn): torch.ops.torchcodec_ns._get_json_ffmpeg_library_versions.default ) _get_backend_details = torch.ops.torchcodec_ns._get_backend_details.default +create_streaming_encoder_to_file = torch._dynamo.disallow_in_graph( + torch.ops.torchcodec_ns.create_streaming_encoder_to_file.default +) +streaming_encoder_close = torch.ops.torchcodec_ns.streaming_encoder_close.default +streaming_encoder_add_video_stream = ( + torch.ops.torchcodec_ns.streaming_encoder_add_video_stream.default +) +streaming_encoder_add_frames = ( + torch.ops.torchcodec_ns.streaming_encoder_add_frames.default +) +set_nvdec_cache_capacity = torch.ops.torchcodec_ns.set_nvdec_cache_capacity.default +get_nvdec_cache_capacity = torch.ops.torchcodec_ns.get_nvdec_cache_capacity.default +_get_nvdec_cache_size = torch.ops.torchcodec_ns._get_nvdec_cache_size.default +create_wav_decoder_from_file = ( + torch.ops.torchcodec_ns.create_wav_decoder_from_file.default +) +get_wav_all_samples = torch.ops.torchcodec_ns.get_wav_all_samples.default +get_wav_metadata_from_decoder = ( + torch.ops.torchcodec_ns.get_wav_metadata_from_decoder.default +) # ============================= # Functions not related to custom ops, but similar implementation to c++ ops # ============================= -def create_from_bytes( - video_bytes: bytes, seek_mode: Optional[str] = None -) -> torch.Tensor: +def create_from_bytes(video_bytes: bytes, seek_mode: str | None = None) -> torch.Tensor: with warnings.catch_warnings(): # Ignore warning stating that the underlying video_bytes buffer is # non-writable. 
@@ -173,7 +185,7 @@ def create_from_bytes( def create_from_file_like( - file_like: Union[io.RawIOBase, io.BufferedReader], seek_mode: Optional[str] = None + file_like: io.RawIOBase | io.BufferedReader, seek_mode: str | None = None ) -> torch.Tensor: assert _pybind_ops is not None return _create_from_file_like( @@ -188,10 +200,10 @@ def encode_audio_to_file_like( samples: torch.Tensor, sample_rate: int, format: str, - file_like: Union[io.RawIOBase, io.BufferedIOBase], - bit_rate: Optional[int] = None, - num_channels: Optional[int] = None, - desired_sample_rate: Optional[int] = None, + file_like: io.RawIOBase | io.BufferedIOBase, + bit_rate: int | None = None, + num_channels: int | None = None, + desired_sample_rate: int | None = None, ) -> None: """Encode audio samples to a file-like object. @@ -222,19 +234,27 @@ def encode_audio_to_file_like( def encode_video_to_file_like( frames: torch.Tensor, - frame_rate: int, + frame_rate: float, format: str, - file_like: Union[io.RawIOBase, io.BufferedIOBase], - crf: Optional[int] = None, + file_like: io.RawIOBase | io.BufferedIOBase, + codec: str | None = None, + pixel_format: str | None = None, + crf: int | float | None = None, + preset: str | None = None, + extra_options: list[str] | None = None, ) -> None: """Encode video frames to a file-like object. Args: - frames: Video frames tensor + frames: Video frames tensor. The device of the frames tensor will be used for encoding. 
frame_rate: Frame rate in frames per second format: Video format (e.g., "mp4", "mov", "mkv") file_like: File-like object that supports write() and seek() methods + codec: Optional codec name (e.g., "libx264", "h264") + pixel_format: Optional pixel format (e.g., "yuv420p", "yuv444p") crf: Optional constant rate factor for encoding quality + preset: Optional encoder preset as string (e.g., "ultrafast", "medium") + extra_options: Optional list of extra options as flattened key-value pairs """ assert _pybind_ops is not None @@ -243,13 +263,17 @@ def encode_video_to_file_like( frame_rate, format, _pybind_ops.create_file_like_context(file_like, True), # True means for writing + codec, + pixel_format, crf, + preset, + extra_options, ) def get_frames_at_indices( - decoder: torch.Tensor, *, frame_indices: Union[torch.Tensor, list[int]] -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + decoder: torch.Tensor, *, frame_indices: torch.Tensor | list[int] +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: if isinstance(frame_indices, torch.Tensor): # Ensure indices is the correct dtype (int64) frame_indices = frame_indices.to(torch.int64) @@ -260,8 +284,8 @@ def get_frames_at_indices( def get_frames_by_pts( - decoder: torch.Tensor, *, timestamps: Union[torch.Tensor, list[float]] -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + decoder: torch.Tensor, *, timestamps: torch.Tensor | list[float] +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: if isinstance(timestamps, torch.Tensor): # Ensure indices is the correct dtype (float64) timestamps = timestamps.to(torch.float64) @@ -278,13 +302,13 @@ def get_frames_by_pts( # Abstract impl for the operators. Needed by torch.compile. 
# ============================== @register_fake("torchcodec_ns::create_from_file") -def create_from_file_abstract(filename: str, seek_mode: Optional[str]) -> torch.Tensor: +def create_from_file_abstract(filename: str, seek_mode: str | None) -> torch.Tensor: return torch.empty([], dtype=torch.long) @register_fake("torchcodec_ns::_create_from_file_like") def _create_from_file_like_abstract( - file_like: int, seek_mode: Optional[str] + file_like: int, seek_mode: str | None ) -> torch.Tensor: return torch.empty([], dtype=torch.long) @@ -294,9 +318,9 @@ def encode_audio_to_file_abstract( samples: torch.Tensor, sample_rate: int, filename: str, - bit_rate: Optional[int] = None, - num_channels: Optional[int] = None, - desired_sample_rate: Optional[int] = None, + bit_rate: int | None = None, + num_channels: int | None = None, + desired_sample_rate: int | None = None, ) -> None: return @@ -306,9 +330,9 @@ def encode_audio_to_tensor_abstract( samples: torch.Tensor, sample_rate: int, format: str, - bit_rate: Optional[int] = None, - num_channels: Optional[int] = None, - desired_sample_rate: Optional[int] = None, + bit_rate: int | None = None, + num_channels: int | None = None, + desired_sample_rate: int | None = None, ) -> torch.Tensor: return torch.empty([], dtype=torch.long) @@ -319,9 +343,9 @@ def _encode_audio_to_file_like_abstract( sample_rate: int, format: str, file_like_context: int, - bit_rate: Optional[int] = None, - num_channels: Optional[int] = None, - desired_sample_rate: Optional[int] = None, + bit_rate: int | None = None, + num_channels: int | None = None, + desired_sample_rate: int | None = None, ) -> None: return @@ -329,9 +353,13 @@ def _encode_audio_to_file_like_abstract( @register_fake("torchcodec_ns::encode_video_to_file") def encode_video_to_file_abstract( frames: torch.Tensor, - frame_rate: int, + frame_rate: float, filename: str, - crf: Optional[int], + codec: str | None = None, + pixel_format: str | None = None, + preset: str | None = None, + crf: int | 
float | None = None, + extra_options: list[str] | None = None, ) -> None: return @@ -339,9 +367,13 @@ def encode_video_to_file_abstract( @register_fake("torchcodec_ns::encode_video_to_tensor") def encode_video_to_tensor_abstract( frames: torch.Tensor, - frame_rate: int, + frame_rate: float, format: str, - crf: Optional[int], + codec: str | None = None, + pixel_format: str | None = None, + preset: str | None = None, + crf: int | float | None = None, + extra_options: list[str] | None = None, ) -> torch.Tensor: return torch.empty([], dtype=torch.long) @@ -349,17 +381,21 @@ def encode_video_to_tensor_abstract( @register_fake("torchcodec_ns::_encode_video_to_file_like") def _encode_video_to_file_like_abstract( frames: torch.Tensor, - frame_rate: int, + frame_rate: float, format: str, file_like_context: int, - crf: Optional[int] = None, + codec: str | None = None, + pixel_format: str | None = None, + preset: str | None = None, + crf: int | float | None = None, + extra_options: list[str] | None = None, ) -> None: return @register_fake("torchcodec_ns::create_from_tensor") def create_from_tensor_abstract( - video_tensor: torch.Tensor, seek_mode: Optional[str] + video_tensor: torch.Tensor, seek_mode: str | None ) -> torch.Tensor: return torch.empty([], dtype=torch.long) @@ -368,16 +404,16 @@ def create_from_tensor_abstract( def _add_video_stream_abstract( decoder: torch.Tensor, *, - num_threads: Optional[int] = None, - dimension_order: Optional[str] = None, - stream_index: Optional[int] = None, + num_threads: int | None = None, + dimension_order: str | None = None, + stream_index: int | None = None, device: str = "cpu", device_variant: str = "ffmpeg", transform_specs: str = "", - custom_frame_mappings: Optional[ - tuple[torch.Tensor, torch.Tensor, torch.Tensor] - ] = None, - color_conversion_library: Optional[str] = None, + custom_frame_mappings_pts: torch.Tensor | None = None, + custom_frame_mappings_duration: torch.Tensor | None = None, + 
custom_frame_mappings_keyframe_indices: torch.Tensor | None = None, + color_conversion_library: str | None = None, ) -> None: return @@ -386,15 +422,15 @@ def _add_video_stream_abstract( def add_video_stream_abstract( decoder: torch.Tensor, *, - num_threads: Optional[int] = None, - dimension_order: Optional[str] = None, - stream_index: Optional[int] = None, + num_threads: int | None = None, + dimension_order: str | None = None, + stream_index: int | None = None, device: str = "cpu", device_variant: str = "ffmpeg", transform_specs: str = "", - custom_frame_mappings: Optional[ - tuple[torch.Tensor, torch.Tensor, torch.Tensor] - ] = None, + custom_frame_mappings_pts: torch.Tensor | None = None, + custom_frame_mappings_duration: torch.Tensor | None = None, + custom_frame_mappings_keyframe_indices: torch.Tensor | None = None, ) -> None: return @@ -403,9 +439,9 @@ def add_video_stream_abstract( def add_audio_stream_abstract( decoder: torch.Tensor, *, - stream_index: Optional[int] = None, - sample_rate: Optional[int] = None, - num_channels: Optional[int] = None, + stream_index: int | None = None, + sample_rate: int | None = None, + num_channels: int | None = None, ) -> None: return @@ -418,7 +454,7 @@ def seek_abstract(decoder: torch.Tensor, seconds: float) -> None: @register_fake("torchcodec_ns::get_next_frame") def get_next_frame_abstract( decoder: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # Images are 3 dimensions: height, width, channels. # The exact permutation depends on the constructor options passed in. 
image_size = [get_ctx().new_dynamic_size() for _ in range(3)] @@ -432,7 +468,7 @@ def get_next_frame_abstract( @register_fake("torchcodec_ns::get_frame_at_pts") def get_frame_at_pts_abstract( decoder: torch.Tensor, seconds: float -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: image_size = [get_ctx().new_dynamic_size() for _ in range(3)] return ( torch.empty(image_size), @@ -445,8 +481,8 @@ def get_frame_at_pts_abstract( def get_frames_by_pts_abstract( decoder: torch.Tensor, *, - timestamps: Union[torch.Tensor, List[float]], -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + timestamps: torch.Tensor | list[float], +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: image_size = [get_ctx().new_dynamic_size() for _ in range(4)] return ( torch.empty(image_size), @@ -458,7 +494,7 @@ def get_frames_by_pts_abstract( @register_fake("torchcodec_ns::get_frame_at_index") def get_frame_at_index_abstract( decoder: torch.Tensor, *, frame_index: int -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: image_size = [get_ctx().new_dynamic_size() for _ in range(3)] return ( torch.empty(image_size), @@ -469,8 +505,8 @@ def get_frame_at_index_abstract( @register_fake("torchcodec_ns::get_frames_at_indices") def get_frames_at_indices_abstract( - decoder: torch.Tensor, *, frame_indices: Union[torch.Tensor, List[int]] -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + decoder: torch.Tensor, *, frame_indices: torch.Tensor | list[int] +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: image_size = [get_ctx().new_dynamic_size() for _ in range(4)] return ( torch.empty(image_size), @@ -485,8 +521,8 @@ def get_frames_in_range_abstract( *, start: int, stop: int, - step: Optional[int] = None, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + step: int | None = None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: image_size = 
[get_ctx().new_dynamic_size() for _ in range(4)] return ( torch.empty(image_size), @@ -501,7 +537,8 @@ def get_frames_by_pts_in_range_abstract( *, start_seconds: float, stop_seconds: float, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + fps: float | None = None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: image_size = [get_ctx().new_dynamic_size() for _ in range(4)] return ( torch.empty(image_size), @@ -515,8 +552,8 @@ def get_frames_by_pts_in_range_audio_abstract( decoder: torch.Tensor, *, start_seconds: float, - stop_seconds: Optional[float] = None, -) -> Tuple[torch.Tensor, torch.Tensor]: + stop_seconds: float | None = None, +) -> tuple[torch.Tensor, torch.Tensor]: image_size = [get_ctx().new_dynamic_size() for _ in range(4)] return (torch.empty(image_size), torch.empty([], dtype=torch.float)) @@ -569,3 +606,55 @@ def get_ffmpeg_library_versions(): @register_fake("torchcodec_ns::_get_backend_details") def _get_backend_details_abstract(decoder: torch.Tensor) -> str: return "" + + +@register_fake("torchcodec_ns::create_streaming_encoder_to_file") +def _create_streaming_encoder_to_file_abstract( + filename: str, +) -> torch.Tensor: + return torch.empty([], dtype=torch.long) + + +@register_fake("torchcodec_ns::streaming_encoder_close") +def streaming_encoder_close_abstract(encoder: torch.Tensor) -> None: + return + + +@register_fake("torchcodec_ns::streaming_encoder_add_video_stream") +def streaming_encoder_add_video_stream_abstract( + encoder: torch.Tensor, + frame_rate: float, + codec: str | None = None, + pixel_format: str | None = None, + crf: float | None = None, + preset: str | None = None, + extra_options: list[str] | None = None, +) -> None: + return + + +@register_fake("torchcodec_ns::streaming_encoder_add_frames") +def streaming_encoder_add_frames_abstract( + encoder: torch.Tensor, frames: torch.Tensor +) -> None: + return + + +@register_fake("torchcodec_ns::set_nvdec_cache_capacity") +def set_nvdec_cache_capacity_abstract(capacity: 
int) -> None: + return + + +@register_fake("torchcodec_ns::get_nvdec_cache_capacity") +def get_nvdec_cache_capacity_abstract() -> int: + return 0 + + +@register_fake("torchcodec_ns::_get_nvdec_cache_size") +def _get_nvdec_cache_size_abstract(device_index: int) -> int: + return 0 + + +@register_fake("torchcodec_ns::get_wav_metadata_from_decoder") +def get_wav_metadata_from_decoder_abstract(decoder: torch.Tensor) -> str: + return "" diff --git a/src/torchcodec/_core/pybind_ops.cpp b/src/torchcodec/_core/pybind_ops.cpp index f9cfd7eb6..1f2e792ee 100644 --- a/src/torchcodec/_core/pybind_ops.cpp +++ b/src/torchcodec/_core/pybind_ops.cpp @@ -8,7 +8,7 @@ #include #include -#include "src/torchcodec/_core/AVIOFileLikeContext.h" +#include "AVIOFileLikeContext.h" namespace py = pybind11; diff --git a/src/torchcodec/_frame.py b/src/torchcodec/_frame.py index b5d7d9d5a..2ceb890b7 100644 --- a/src/torchcodec/_frame.py +++ b/src/torchcodec/_frame.py @@ -4,9 +4,10 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+ import dataclasses +from collections.abc import Iterable, Iterator from dataclasses import dataclass -from typing import Iterable, Iterator, Union from torch import Tensor @@ -46,7 +47,7 @@ def __post_init__(self): self.pts_seconds = float(self.pts_seconds) self.duration_seconds = float(self.duration_seconds) - def __iter__(self) -> Iterator[Union[Tensor, float]]: + def __iter__(self) -> Iterator[Tensor | float]: for field in dataclasses.fields(self): yield getattr(self, field.name) @@ -137,7 +138,7 @@ def __post_init__(self): self.pts_seconds = float(self.pts_seconds) self.sample_rate = int(self.sample_rate) - def __iter__(self) -> Iterator[Union[Tensor, float]]: + def __iter__(self) -> Iterator[Tensor | float]: for field in dataclasses.fields(self): yield getattr(self, field.name) diff --git a/src/torchcodec/_internally_replaced_utils.py b/src/torchcodec/_internally_replaced_utils.py index bf8ac3ac5..749908661 100644 --- a/src/torchcodec/_internally_replaced_utils.py +++ b/src/torchcodec/_internally_replaced_utils.py @@ -5,16 +5,24 @@ # LICENSE file in the root directory of this source tree. import importlib +import importlib.util import sys +import traceback from pathlib import Path from types import ModuleType +import torch + +# Note that this value must match the value used as PYBIND_OPS_MODULE_NAME when we compile _core/pybind_ops.cpp. +# If the values do not match, we will not be able to import the C++ shared library as a Python module at runtime. 
+_PYBIND_OPS_MODULE_NAME = "core_pybind_ops" + # Copy pasted from torchvision # https://github.com/pytorch/vision/blob/947ae1dc71867f28021d5bc0ff3a19c249236e2a/torchvision/_internally_replaced_utils.py#L25 def _get_extension_path(lib_name: str) -> str: extension_suffixes = [] - if sys.platform == "linux": + if sys.platform == "linux" or sys.platform.startswith("freebsd"): extension_suffixes = importlib.machinery.EXTENSION_SUFFIXES elif sys.platform == "darwin": extension_suffixes = importlib.machinery.EXTENSION_SUFFIXES + [".dylib"] @@ -56,12 +64,65 @@ def _load_pybind11_module(module_name: str, library_path: str) -> ModuleType: return mod -# Note that the return value from this function must match the value used as -# PYBIND_OPS_MODULE_NAME when we compile _core/pybind_ops.cpp. If the values -# do not match, we will not be able to import the C++ shared library as a -# Python module at runtime. -# -# The parameter ffmpeg_major_version is unused externally, but used -# internally. -def _get_pybind_ops_module_name(ffmpeg_major_version: int) -> str: - return "core_pybind_ops" +def load_torchcodec_shared_libraries() -> tuple[int, str, ModuleType]: + """ + Successively try to load the shared libraries for each version of FFmpeg + that we support. We always start with the highest version, working our way + down to the lowest version. Once we can load ALL shared libraries for a + version of FFmpeg, we have succeeded and we stop. + + Note that we use two different methods for loading shared libraries: + + 1. torch.ops.load_library(): For PyTorch custom ops and the C++ only + libraries the custom ops depend on. Loading libraries through PyTorch + registers the custom ops with PyTorch's runtime and the ops can be + accessed through torch.ops after loading. + + 2. importlib: For pybind11 modules. We load them dynamically, rather + than using a plain import statement. A plain import statement only + works when the module name and file name match exactly. 
Our shared + libraries do not meet those conditions. + """ + exceptions = [] + for ffmpeg_major_version in (8, 7, 6, 5, 4): + core_library_name = f"libtorchcodec_core{ffmpeg_major_version}" + custom_ops_library_name = f"libtorchcodec_custom_ops{ffmpeg_major_version}" + pybind_ops_library_name = f"libtorchcodec_pybind_ops{ffmpeg_major_version}" + try: + core_library_path = _get_extension_path(core_library_name) + torch.ops.load_library(core_library_path) + torch.ops.load_library(_get_extension_path(custom_ops_library_name)) + + pybind_ops_library_path = _get_extension_path(pybind_ops_library_name) + pybind_ops = _load_pybind11_module( + _PYBIND_OPS_MODULE_NAME, pybind_ops_library_path + ) + return ffmpeg_major_version, core_library_path, pybind_ops + except Exception: + # Capture the full traceback for this exception + exc_traceback = traceback.format_exc() + exceptions.append((ffmpeg_major_version, exc_traceback)) + + traceback_info = ( + "\n[start of libtorchcodec loading traceback]\n" + + "\n".join(f"FFmpeg version {v}:\n{tb}" for v, tb in exceptions) + + "[end of libtorchcodec loading traceback]." + ) + raise RuntimeError( + f"""Could not load libtorchcodec. Likely causes: + 1. FFmpeg is not properly installed in your environment. We support + versions 4, 5, 6, 7, and 8, and we attempt to load libtorchcodec + for each of those versions. Errors for versions not installed on + your system are expected; only the error for your installed FFmpeg + version is relevant. On Windows, ensure you've installed the + "full-shared" version which ships DLLs. + 2. The PyTorch version ({torch.__version__}) is not compatible with + this version of TorchCodec. Refer to the version compatibility + table: + https://github.com/pytorch/torchcodec?tab=readme-ov-file#installing-torchcodec. + 3. Another runtime dependency; see exceptions below. 
+ + The following exceptions were raised as we tried to load libtorchcodec: + """ + f"{traceback_info}" + ) diff --git a/src/torchcodec/_samplers/video_clip_sampler.py b/src/torchcodec/_samplers/video_clip_sampler.py index 343728393..43ca32def 100644 --- a/src/torchcodec/_samplers/video_clip_sampler.py +++ b/src/torchcodec/_samplers/video_clip_sampler.py @@ -4,21 +4,20 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. + import abc import json import sys from dataclasses import dataclass, field -from typing import Any, Dict, List, Tuple, Union +from typing import Any import torch from torch import nn, Tensor -from torchcodec._core import ( +from torchcodec._core import get_frames_at_indices, get_json_metadata, get_next_frame +from torchcodec._core.ops import ( add_video_stream, create_from_tensor, - get_frames_at_indices, - get_json_metadata, - get_next_frame, scan_all_streams_to_update_metadata, seek_to_pts, ) @@ -82,7 +81,7 @@ class TimeBasedSamplerArgs(SamplerArgs): sample_start_second: float = 0.0 sample_end_second: float = float("inf") sample_per_second: float = 0.0 - target_sample_start_second: List[float] = field(default_factory=lambda: []) + target_sample_start_second: list[float] = field(default_factory=lambda: []) @dataclass @@ -117,21 +116,21 @@ def __init__( self, video_args: VideoArgs, sampler_args: SamplerArgs, - decoder_args: Union[None, DecoderArgs] = None, + decoder_args: DecoderArgs | None = None, ) -> None: super().__init__() self.video_args = video_args self.sampler_args = sampler_args self.decoder_args = DecoderArgs() if decoder_args is None else decoder_args - def forward(self, video_data: Tensor) -> Union[List[Any]]: + def forward(self, video_data: Tensor) -> list[Any]: """Sample video clips from the video data Args: video_data (`Tensor`): The video data Return - clips (` List[List[Tensor]]`): List of clips, where each clip is a list of Tensors, each tensor 
represents a frame image. + clips (` list[list[Tensor]]`): List of clips, where each clip is a list of Tensors, each tensor represents a frame image. """ @@ -151,7 +150,7 @@ def forward(self, video_data: Tensor) -> Union[List[Any]]: num_threads=self.decoder_args.num_threads, ) - clips: List[Any] = [] + clips: list[Any] = [] # Cast sampler args to be time based or index based if isinstance(self.sampler_args, TimeBasedSamplerArgs): time_based_sampler_args = self.sampler_args @@ -179,8 +178,8 @@ def _get_clips_for_index_based_sampling( self, video_decoder: Tensor, index_based_sampler_args: IndexBasedSamplerArgs, - metadata_json: Dict[str, Any], - ) -> List[Tensor]: + metadata_json: dict[str, Any], + ) -> list[Tensor]: """Get clips for index based sampling, the sampling is done in 3 steps: 1. Compute clip_start_idxs based on the sampler type and the sampler args; 2. For each clip, given clip_start_idx, video_frame_dilation, frames_per_clip, get indexes for all frames @@ -189,10 +188,10 @@ def _get_clips_for_index_based_sampling( Args: video_decoder (`Tensor`): The video decoder index_based_sampler_args (`IndexBasedSamplerArgs`): The index based sampler args - metadata_json (`Dict[str, Any]`): The metadata of the video in json format + metadata_json (`dict[str, Any]`): The metadata of the video in json format Returns: - clips (` List[Tensor]`): List of clips, where each clip is a Tensor represents list of frames, Tensor shape default is NCHW. + clips (` list[Tensor]`): List of clips, where each clip is a Tensor represents list of frames, Tensor shape default is NCHW. 
""" sample_start_index = max(0, index_based_sampler_args.sample_start_index) @@ -226,7 +225,7 @@ def _get_clips_for_index_based_sampling( clip_start_idx + i * index_based_sampler_args.video_frame_dilation for i in range(index_based_sampler_args.frames_per_clip) ] - # Need torch.stack to convert List[Tensor[int]] into 1D Tensor[int] + # Need torch.stack to convert list[Tensor[int]] into 1D Tensor[int] batch_indexes = torch.stack(batch_indexes) frames, *_ = get_frames_at_indices( video_decoder, @@ -238,18 +237,18 @@ def _get_clips_for_index_based_sampling( def _get_start_seconds( self, - metadata_json: Dict[str, Any], + metadata_json: dict[str, Any], time_based_sampler_args: TimeBasedSamplerArgs, - ) -> List[float]: + ) -> list[float]: """Get start seconds for each clip. Given different sampler type, the API returns different clip start seconds. Args: - metadata_json (`Dict[str, Any]`): The metadata of the video in json format + metadata_json (`dict[str, Any]`): The metadata of the video in json format time_based_sampler_args: (`TimeBasedSamplerArgs`): The time based sampler args Returns: - (`List[float]`): List of the sampled clip start position in seconds + (`list[float]`): List of the sampled clip start position in seconds """ video_duration_in_seconds = metadata_json["durationSecondsFromHeader"] @@ -277,7 +276,7 @@ def _get_start_seconds( "Cannot get clips because video duration is shorter than the clip duration!" ) sampler_type = time_based_sampler_args.sampler_type - clip_starts_in_seconds: List[float] = [] + clip_starts_in_seconds: list[float] = [] sample_start_second = max( time_based_sampler_args.sample_start_second, beginStreamSecondsFromContent, @@ -306,7 +305,7 @@ def _get_start_seconds( def _get_clip_with_start_second( self, start_second: float, video_decoder: Tensor, video_frame_dilation: int - ) -> List[Tensor]: + ) -> list[Tensor]: """Get clip with start second. 
Args: @@ -315,7 +314,7 @@ def _get_clip_with_start_second( `video_frame_dilation` (`int`): The video frame dilation, by default it's 1. Returns: - `clip` (`List[Tensor]`): clip is list of frame tensor. Dimension of each frame tensor is user specified, by default it's HWC. + `clip` (`list[Tensor]`): clip is list of frame tensor. Dimension of each frame tensor is user specified, by default it's HWC. """ seek_to_pts(video_decoder, start_second) frames_needed_per_clip = ( @@ -332,7 +331,7 @@ def _get_clip_with_start_second( def _compute_frame_width_height( self, ori_width: int, ori_height: int - ) -> Tuple[int, int]: + ) -> tuple[int, int]: """Compute output frame width and height desired_width, desired_height, desired_min_dimension, desired_max_dimension, (`int`): Together decide the size of the decoded video clips. (Default: `0`). Note that the desired_width/desired_height parameters are mutually exclusive with desired_min_dimension/desired_max_dimension parameters. @@ -364,7 +363,7 @@ def _compute_frame_width_height( ori_height (`int`): Original height of the video Returns: - (`Tuple[int, int]`): output frame width and height + (`tuple[int, int]`): output frame width and height """ width_height_ratio = ori_width / ori_height height_width_ratio = ori_height / ori_width diff --git a/src/torchcodec/decoders/__init__.py b/src/torchcodec/decoders/__init__.py index 980ba98a9..0ab4aedef 100644 --- a/src/torchcodec/decoders/__init__.py +++ b/src/torchcodec/decoders/__init__.py @@ -6,7 +6,11 @@ from .._core import AudioStreamMetadata, VideoStreamMetadata from ._audio_decoder import AudioDecoder # noqa -from ._decoder_utils import set_cuda_backend # noqa -from ._video_decoder import VideoDecoder # noqa +from ._decoder_utils import ( # noqa + get_nvdec_cache_capacity, + set_cuda_backend, + set_nvdec_cache_capacity, +) +from ._video_decoder import CpuFallbackStatus, VideoDecoder # noqa SimpleVideoDecoder = VideoDecoder diff --git a/src/torchcodec/decoders/_audio_decoder.py 
b/src/torchcodec/decoders/_audio_decoder.py index d1e42c196..f4fbdca65 100644 --- a/src/torchcodec/decoders/_audio_decoder.py +++ b/src/torchcodec/decoders/_audio_decoder.py @@ -4,18 +4,15 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. + import io from pathlib import Path -from typing import Optional, Union import torch from torch import Tensor from torchcodec import _core as core, AudioSamples -from torchcodec.decoders._decoder_utils import ( - create_decoder, - ERROR_REPORTING_INSTRUCTIONS, -) +from torchcodec._core._decoder_utils import create_audio_decoder class AudioDecoder: @@ -54,36 +51,26 @@ class AudioDecoder: def __init__( self, - source: Union[str, Path, io.RawIOBase, io.BufferedReader, bytes, Tensor], + source: str | Path | io.RawIOBase | io.BufferedReader | bytes | Tensor, *, - stream_index: Optional[int] = None, - sample_rate: Optional[int] = None, - num_channels: Optional[int] = None, + stream_index: int | None = None, + sample_rate: int | None = None, + num_channels: int | None = None, ): torch._C._log_api_usage_once("torchcodec.decoders.AudioDecoder") - self._decoder = create_decoder(source=source, seek_mode="approximate") - core.add_audio_stream( + ( self._decoder, + self.stream_index, + self.metadata, + ) = create_audio_decoder( + source=source, + seek_mode="approximate", stream_index=stream_index, sample_rate=sample_rate, num_channels=num_channels, ) - container_metadata = core.get_container_metadata(self._decoder) - self.stream_index = ( - container_metadata.best_audio_stream_index - if stream_index is None - else stream_index - ) - if self.stream_index is None: - raise ValueError( - "The best audio stream is unknown and there is no specified stream. 
" - + ERROR_REPORTING_INSTRUCTIONS - ) - self.metadata = container_metadata.streams[self.stream_index] - assert isinstance(self.metadata, core.AudioStreamMetadata) # mypy - self._desired_sample_rate = ( sample_rate if sample_rate is not None else self.metadata.sample_rate ) @@ -100,7 +87,7 @@ def get_all_samples(self) -> AudioSamples: return self.get_samples_played_in_range() def get_samples_played_in_range( - self, start_seconds: float = 0.0, stop_seconds: Optional[float] = None + self, start_seconds: float = 0.0, stop_seconds: float | None = None ) -> AudioSamples: """Returns audio samples in the given range. diff --git a/src/torchcodec/decoders/_decoder_utils.py b/src/torchcodec/decoders/_decoder_utils.py index 2619acd24..30c6a26ff 100644 --- a/src/torchcodec/decoders/_decoder_utils.py +++ b/src/torchcodec/decoders/_decoder_utils.py @@ -4,54 +4,12 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. + import contextvars -import io +from collections.abc import Generator from contextlib import contextmanager -from pathlib import Path - -from typing import Generator, Union - -from torch import Tensor -from torchcodec import _core as core - -ERROR_REPORTING_INSTRUCTIONS = """ -This should never happen. Please report an issue following the steps in -https://github.com/pytorch/torchcodec/issues/new?assignees=&labels=&projects=&template=bug-report.yml. 
-""" - - -def create_decoder( - *, - source: Union[str, Path, io.RawIOBase, io.BufferedReader, bytes, Tensor], - seek_mode: str, -) -> Tensor: - if isinstance(source, str): - return core.create_from_file(source, seek_mode) - elif isinstance(source, Path): - return core.create_from_file(str(source), seek_mode) - elif isinstance(source, io.RawIOBase) or isinstance(source, io.BufferedReader): - return core.create_from_file_like(source, seek_mode) - elif isinstance(source, bytes): - return core.create_from_bytes(source, seek_mode) - elif isinstance(source, Tensor): - return core.create_from_tensor(source, seek_mode) - elif isinstance(source, io.TextIOBase): - raise TypeError( - "source is for reading text, likely from open(..., 'r'). Try with 'rb' for binary reading?" - ) - elif hasattr(source, "read") and hasattr(source, "seek"): - # This check must be after checking for text-based reading. Also placing - # it last in general to be defensive: hasattr is a blunt instrument. We - # could use the inspect module to check for methods with the right - # signature. - return core.create_from_file_like(source, seek_mode) - raise TypeError( - f"Unknown source type: {type(source)}. " - "Supported types are str, Path, bytes, Tensor and file-like objects with " - "read(self, size: int) -> bytes and " - "seek(self, offset: int, whence: int) -> int methods." - ) +from torchcodec import _core # Thread-local and async-safe storage for the current CUDA backend @@ -110,3 +68,37 @@ def set_cuda_backend(backend: str) -> Generator[None, None, None]: def _get_cuda_backend() -> str: return _CUDA_BACKEND.get() + + +def set_nvdec_cache_capacity(capacity: int) -> None: + """Set the maximum number of NVDEC decoders that can be cached (per GPU). + + The NVDEC decoder cache stores hardware decoders for reuse, avoiding the + overhead of creating and destructing new decoders for subsequent video + decoding operations on the same GPU. This function sets the capacity of the + cache, i.e. 
the maximum number of decoders that can be cached per device. + The default capacity is 20 decoders per device. If the cache contains more + decoders than the target ``capacity``, excess decoders will be evicted + using a least-recently-used policy. + + Generally, a decoder can be re-used from the cache if it matches the same + codec and frame dimensions. + + See also :func:`~torchcodec.decoders.get_nvdec_cache_capacity`. + + Args: + capacity (int): The maximum number of NVDEC decoders that can be cached + per GPU device. Must be non-negative. Setting to 0 disables caching. + """ + _core.set_nvdec_cache_capacity(capacity) + + +def get_nvdec_cache_capacity() -> int: + """Get the capacity of the per-device NVDEC decoder cache. + + See also :func:`~torchcodec.decoders.set_nvdec_cache_capacity`. + + Returns: + int: The maximum number of NVDEC decoders that can be cached per GPU device. + """ + return _core.get_nvdec_cache_capacity() diff --git a/src/torchcodec/decoders/_video_decoder.py b/src/torchcodec/decoders/_video_decoder.py index 6b524f119..bc3a2fc81 100644 --- a/src/torchcodec/decoders/_video_decoder.py +++ b/src/torchcodec/decoders/_video_decoder.py @@ -4,21 +4,71 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+ import io import json import numbers +from collections.abc import Sequence +from dataclasses import dataclass, field from pathlib import Path -from typing import Literal, Optional, Tuple, Union +from typing import Literal import torch -from torch import device as torch_device, Tensor - +from torch import device as torch_device, nn, Tensor from torchcodec import _core as core, Frame, FrameBatch -from torchcodec.decoders._decoder_utils import ( - _get_cuda_backend, - create_decoder, - ERROR_REPORTING_INSTRUCTIONS, -) +from torchcodec._core._decoder_utils import create_video_decoder +from torchcodec.decoders._decoder_utils import _get_cuda_backend +from torchcodec.transforms import DecoderTransform + + +@dataclass +class CpuFallbackStatus: + """Information about CPU fallback status. + + This class tracks whether the decoder fell back to CPU decoding. + Users should not instantiate this class directly; instead, access it + via the :attr:`VideoDecoder.cpu_fallback` attribute. + + Usage: + + - Use ``str(cpu_fallback_status)`` or ``print(cpu_fallback_status)`` to see the cpu fallback status + - Use ``if cpu_fallback_status:`` to check if any fallback occurred + """ + + status_known: bool = False + """Whether the fallback status has been determined. + For the Beta CUDA backend (see :func:`~torchcodec.decoders.set_cuda_backend`), + this is always ``True`` immediately after decoder creation. 
+ For the FFmpeg CUDA backend, this becomes ``True`` after decoding + the first frame.""" + _nvcuvid_unavailable: bool = field(default=False, init=False) + _video_not_supported: bool = field(default=False, init=False) + _is_fallback: bool = field(default=False, init=False) + _backend: str = field(default="", init=False) + + def __bool__(self): + """Returns True if fallback occurred.""" + return self.status_known and self._is_fallback + + def __str__(self): + """Returns a human-readable string representation of the cpu fallback status.""" + if not self.status_known: + return f"[{self._backend}] Fallback status: Unknown" + + reasons = [] + if self._nvcuvid_unavailable: + reasons.append("NVcuvid unavailable") + elif self._video_not_supported: + reasons.append("Video not supported") + elif self._is_fallback: + reasons.append("Unknown reason - try the Beta interface to know more!") + + if reasons: + return ( + f"[{self._backend}] Fallback status: Falling back due to: " + + ", ".join(reasons) + ) + return f"[{self._backend}] Fallback status: No fallback required" class VideoDecoder: @@ -49,13 +99,15 @@ class VideoDecoder: cheap no-copy operation that allows these frames to be transformed using the `torchvision transforms `_. - num_ffmpeg_threads (int, optional): The number of threads to use for decoding. + num_ffmpeg_threads (int, optional): The number of threads to use for CPU decoding. + This has no effect when using GPU decoding. Use 1 for single-threaded decoding which may be best if you are running multiple instances of ``VideoDecoder`` in parallel. Use a higher number for multi-threaded decoding which is best if you are running a single instance of ``VideoDecoder``. Passing 0 lets FFmpeg decide on the number of threads. Default: 1. - device (str or torch.device, optional): The device to use for decoding. Default: "cpu". + device (str or torch.device, optional): The device to use for decoding. + If ``None`` (default), uses the current default device. 
If you pass a CUDA device, we recommend trying the "beta" CUDA backend which is faster! See :func:`~torchcodec.decoders.set_cuda_backend`. seek_mode (str, optional): Determines if frame access will be "exact" or @@ -66,6 +118,11 @@ class VideoDecoder: probably is. Default: "exact". Read more about this parameter in: :ref:`sphx_glr_generated_examples_decoding_approximate_mode.py` + transforms (sequence of transform objects, optional): Sequence of transforms to be + applied to the decoded frames by the decoder itself, in order. Accepts both + :class:`~torchcodec.transforms.DecoderTransform` and + :class:`~torchvision.transforms.v2.Transform` + objects. Read more about this parameter in: TODO_DECODER_TRANSFORMS_TUTORIAL. custom_frame_mappings (str, bytes, or file-like object, optional): Mapping of frames to their metadata, typically generated via ffprobe. This enables accurate frame seeking without requiring a full video scan. @@ -93,20 +150,25 @@ class VideoDecoder: stream_index (int): The stream index that this decoder is retrieving frames from. If a stream index was provided at initialization, this is the same value. If it was left unspecified, this is the :term:`best stream`. + cpu_fallback (CpuFallbackStatus): Information about whether the decoder fell back to CPU + decoding. Use ``bool(cpu_fallback)`` to check if fallback occurred, or + ``str(cpu_fallback)`` to get a human-readable status message. The status is only + determined after at least one frame has been decoded. 
""" def __init__( self, - source: Union[str, Path, io.RawIOBase, io.BufferedReader, bytes, Tensor], + source: str | Path | io.RawIOBase | io.BufferedReader | bytes | Tensor, *, - stream_index: Optional[int] = None, + stream_index: int | None = None, dimension_order: Literal["NCHW", "NHWC"] = "NCHW", num_ffmpeg_threads: int = 1, device: Optional[Union[str, "torch_device"]] = "cpu", seek_mode: Literal["exact", "approximate"] = "exact", - custom_frame_mappings: Optional[ - Union[str, bytes, io.RawIOBase, io.BufferedReader] - ] = None, + transforms: Sequence[DecoderTransform | nn.Module] | None = None, + custom_frame_mappings: ( + str | bytes | io.RawIOBase | io.BufferedReader | None + ) = None, ): torch._C._log_api_usage_once("torchcodec.decoders.VideoDecoder") allowed_seek_modes = ("exact", "approximate") @@ -131,8 +193,6 @@ def __init__( custom_frame_mappings ) - self._decoder = create_decoder(source=source, seek_mode=seek_mode) - allowed_dimension_orders = ("NCHW", "NHWC") if dimension_order not in allowed_dimension_orders: raise ValueError( @@ -155,8 +215,11 @@ def __init__( raise NotImplementedError(f"{device} is not supported yet") device_variant = _get_cuda_backend() - - core.add_video_stream( + if device is None: + device = str(torch.get_default_device()) + elif isinstance(device, torch_device): + device = str(device) + ( self._decoder, num_threads=num_ffmpeg_threads, dimension_order=dimension_order, @@ -167,19 +230,53 @@ def __init__( custom_frame_mappings=custom_frame_mappings_data, ) - ( - self.metadata, - self.stream_index, - self._begin_stream_seconds, - self._end_stream_seconds, - self._num_frames, - ) = _get_and_validate_stream_metadata( - decoder=self._decoder, stream_index=stream_index - ) + assert self.metadata.begin_stream_seconds is not None # mypy. + assert self.metadata.end_stream_seconds is not None # mypy. + assert self.metadata.num_frames is not None # mypy. 
+ + self._begin_stream_seconds = self.metadata.begin_stream_seconds + self._end_stream_seconds = self.metadata.end_stream_seconds + self._num_frames = self.metadata.num_frames + + self._cpu_fallback = CpuFallbackStatus() + if device.startswith("cuda"): + if device_variant == "beta": + self._cpu_fallback._backend = "Beta CUDA" + else: + self._cpu_fallback._backend = "FFmpeg CUDA" + else: + self._cpu_fallback._backend = "CPU" def __len__(self) -> int: return self._num_frames + @property + def cpu_fallback(self) -> CpuFallbackStatus: + # We only query the CPU fallback info if status is unknown. That happens + # either when: + # - this @property has never been called before + # - no frame has been decoded yet on the FFmpeg interface. + # Note that for the beta interface, we're able to know the fallback status + # right when the VideoDecoder is instantiated, but the status_known + # attribute is initialized to False. + if not self._cpu_fallback.status_known: + backend_details = core._get_backend_details(self._decoder) + + if "status unknown" not in backend_details: + self._cpu_fallback.status_known = True + + if "CPU fallback" in backend_details: + self._cpu_fallback._is_fallback = True + if self._cpu_fallback._backend == "Beta CUDA": + # Only the beta interface can provide details. + # if it's not that nvcuvid is missing, it must be video-specific + if "NVCUVID not available" in backend_details: + self._cpu_fallback._nvcuvid_unavailable = True + else: + self._cpu_fallback._video_not_supported = True + + return self._cpu_fallback + def _getitem_int(self, key: int) -> Tensor: assert isinstance(key, int) @@ -198,7 +295,7 @@ def _getitem_slice(self, key: slice) -> Tensor: ) return frame_data - def __getitem__(self, key: Union[numbers.Integral, slice]) -> Tensor: + def __getitem__(self, key: numbers.Integral | slice) -> Tensor: """Return frame or frames as tensors, at the given index or range. .. 
note:: @@ -255,7 +352,7 @@ def get_frame_at(self, index: int) -> Frame: duration_seconds=duration_seconds.item(), ) - def get_frames_at(self, indices: Union[torch.Tensor, list[int]]) -> FrameBatch: + def get_frames_at(self, indices: torch.Tensor | list[int]) -> FrameBatch: """Return frames at the given indices. Args: @@ -335,9 +432,7 @@ def get_frame_played_at(self, seconds: float) -> Frame: duration_seconds=duration_seconds.item(), ) - def get_frames_played_at( - self, seconds: Union[torch.Tensor, list[float]] - ) -> FrameBatch: + def get_frames_played_at(self, seconds: torch.Tensor | list[float]) -> FrameBatch: """Return frames played at the given timestamps in seconds. Args: @@ -360,7 +455,7 @@ def get_frames_played_at( ) def get_frames_played_in_range( - self, start_seconds: float, stop_seconds: float + self, start_seconds: float, stop_seconds: float, fps: float | None = None ) -> FrameBatch: """Returns multiple frames in the given range. @@ -369,23 +464,26 @@ def get_frames_played_in_range( range. Args: - start_seconds (float): Time, in seconds, of the start of the - range. - stop_seconds (float): Time, in seconds, of the end of the - range. As a half open range, the end is excluded. + start_seconds (float): Time, in seconds, of the start of the range. + stop_seconds (float): Time, in seconds, of the end of the range. + As a half open range, the end is excluded. + fps (float, optional): If specified, resample output to this frame + rate by duplicating or dropping frames as necessary. If None + (default), returns frames at the source video's frame rate. Returns: FrameBatch: The frames within the specified range. """ if not start_seconds <= stop_seconds: raise ValueError( - f"Invalid start seconds: {start_seconds}. It must be less than or equal to stop seconds ({stop_seconds})." + f"Invalid start seconds: {start_seconds}. " + f"It must be less than or equal to stop seconds ({stop_seconds})." 
) if not self._begin_stream_seconds <= start_seconds < self._end_stream_seconds: raise ValueError( f"Invalid start seconds: {start_seconds}. " f"It must be greater than or equal to {self._begin_stream_seconds} " - f"and less than or equal to {self._end_stream_seconds}." + f"and less than {self._end_stream_seconds}." ) if not stop_seconds <= self._end_stream_seconds: raise ValueError( @@ -396,59 +494,30 @@ def get_frames_played_in_range( self._decoder, start_seconds=start_seconds, stop_seconds=stop_seconds, + fps=fps, ) return FrameBatch(*frames) + def get_all_frames(self, fps: float | None = None) -> FrameBatch: + """Returns all frames in the video. -def _get_and_validate_stream_metadata( - *, - decoder: Tensor, - stream_index: Optional[int] = None, -) -> Tuple[core._metadata.VideoStreamMetadata, int, float, float, int]: - - container_metadata = core.get_container_metadata(decoder) - - if stream_index is None: - if (stream_index := container_metadata.best_video_stream_index) is None: - raise ValueError( - "The best video stream is unknown and there is no specified stream. " - + ERROR_REPORTING_INSTRUCTIONS - ) - - metadata = container_metadata.streams[stream_index] - assert isinstance(metadata, core._metadata.VideoStreamMetadata) # mypy - - if metadata.begin_stream_seconds is None: - raise ValueError( - "The minimum pts value in seconds is unknown. " - + ERROR_REPORTING_INSTRUCTIONS - ) - begin_stream_seconds = metadata.begin_stream_seconds - - if metadata.end_stream_seconds is None: - raise ValueError( - "The maximum pts value in seconds is unknown. " - + ERROR_REPORTING_INSTRUCTIONS - ) - end_stream_seconds = metadata.end_stream_seconds + Args: + fps (float, optional): If specified, resample output to this frame + rate by duplicating or dropping frames as necessary. If None + (default), returns frames at the source video's frame rate. - if metadata.num_frames is None: - raise ValueError( - "The number of frames is unknown. 
" + ERROR_REPORTING_INSTRUCTIONS + Returns: + FrameBatch: All frames in the video. + """ + return self.get_frames_played_in_range( + start_seconds=self._begin_stream_seconds, + stop_seconds=self._end_stream_seconds, + fps=fps, ) - num_frames = metadata.num_frames - - return ( - metadata, - stream_index, - begin_stream_seconds, - end_stream_seconds, - num_frames, - ) def _read_custom_frame_mappings( - custom_frame_mappings: Union[str, bytes, io.RawIOBase, io.BufferedReader] + custom_frame_mappings: str | bytes | io.RawIOBase | io.BufferedReader, ) -> tuple[Tensor, Tensor, Tensor]: """Parse custom frame mappings from JSON data and extract frame metadata. diff --git a/src/torchcodec/decoders/_wav_decoder.py b/src/torchcodec/decoders/_wav_decoder.py new file mode 100644 index 000000000..63e49e2c6 --- /dev/null +++ b/src/torchcodec/decoders/_wav_decoder.py @@ -0,0 +1,43 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ + +import json +from pathlib import Path + +import torch +from torchcodec import _core +from torchcodec._core._decoder_utils import create_wav_decoder +from torchcodec._core._metadata import AudioStreamMetadata + + +class WavDecoder: + # TODO: Docstrings + def __init__(self, source: str | Path): + torch._C._log_api_usage_once("torchcodec.decoders.WavDecoder") + + self._decoder = create_wav_decoder(source) + self._source = source + self.stream_index = 0 + + metadata_json = _core.get_wav_metadata_from_decoder(self._decoder) + metadata_dict = json.loads(metadata_json) + + self.metadata = AudioStreamMetadata( + sample_rate=metadata_dict["sampleRate"], + num_channels=metadata_dict["numChannels"], + sample_format=metadata_dict["sampleFormat"], + duration_seconds=metadata_dict["durationSeconds"], + stream_index=metadata_dict["streamIndex"], + codec=metadata_dict["codec"], + bit_rate=metadata_dict["bitRate"], + duration_seconds_from_header=metadata_dict["durationSecondsFromHeader"], + begin_stream_seconds=metadata_dict["beginStreamSeconds"], + begin_stream_seconds_from_header=None, # WAV format lacks stream start time metadata + ) + + def get_all_samples(self): + return _core.get_wav_all_samples(self._decoder) diff --git a/src/torchcodec/encoders/_audio_encoder.py b/src/torchcodec/encoders/_audio_encoder.py index fc8879cfa..769a98acc 100644 --- a/src/torchcodec/encoders/_audio_encoder.py +++ b/src/torchcodec/encoders/_audio_encoder.py @@ -1,5 +1,4 @@ from pathlib import Path -from typing import Optional, Union import torch from torch import Tensor @@ -44,11 +43,11 @@ def __init__(self, samples: Tensor, *, sample_rate: int): def to_file( self, - dest: Union[str, Path], + dest: str | Path, *, - bit_rate: Optional[int] = None, - num_channels: Optional[int] = None, - sample_rate: Optional[int] = None, + bit_rate: int | None = None, + num_channels: int | None = None, + sample_rate: int | None = None, ) -> None: """Encode samples into a file. 
@@ -79,9 +78,9 @@ def to_tensor( self, format: str, *, - bit_rate: Optional[int] = None, - num_channels: Optional[int] = None, - sample_rate: Optional[int] = None, + bit_rate: int | None = None, + num_channels: int | None = None, + sample_rate: int | None = None, ) -> Tensor: """Encode samples into raw bytes, as a 1D uint8 Tensor. @@ -115,9 +114,9 @@ def to_file_like( file_like, format: str, *, - bit_rate: Optional[int] = None, - num_channels: Optional[int] = None, - sample_rate: Optional[int] = None, + bit_rate: int | None = None, + num_channels: int | None = None, + sample_rate: int | None = None, ) -> None: """Encode samples into a file-like object. diff --git a/src/torchcodec/encoders/_video_encoder.py b/src/torchcodec/encoders/_video_encoder.py index f6a725278..5d98b159f 100644 --- a/src/torchcodec/encoders/_video_encoder.py +++ b/src/torchcodec/encoders/_video_encoder.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Union +from typing import Any import torch from torch import Tensor @@ -8,17 +8,19 @@ class VideoEncoder: - """A video encoder. + """A video encoder on CPU or CUDA.. Args: frames (``torch.Tensor``): The frames to encode. This must be a 4D tensor of shape ``(N, C, H, W)`` where N is the number of frames, C is 3 channels (RGB), H is height, and W is width. Values must be uint8 in the range ``[0, 255]``. - frame_rate (int): The frame rate of the **input** ``frames``. Also defines the encoded **output** frame rate. + The tensor can be on CPU or CUDA. The device of the tensor + determines which encoder is used (CPU or GPU). + frame_rate (float): The frame rate of the **input** ``frames``. Also defines the encoded **output** frame rate. 
""" - def __init__(self, frames: Tensor, *, frame_rate: int): + def __init__(self, frames: Tensor, *, frame_rate: float): torch._C._log_api_usage_once("torchcodec.encoders.VideoEncoder") if not isinstance(frames, Tensor): raise ValueError(f"Expected frames to be a Tensor, got {type(frames) = }.") @@ -34,7 +36,13 @@ def __init__(self, frames: Tensor, *, frame_rate: int): def to_file( self, - dest: Union[str, Path], + dest: str | Path, + *, + codec: str | None = None, + pixel_format: str | None = None, + crf: int | float | None = None, + preset: str | int | None = None, + extra_options: dict[str, Any] | None = None, ) -> None: """Encode frames into a file. @@ -42,36 +50,104 @@ def to_file( dest (str or ``pathlib.Path``): The path to the output file, e.g. ``video.mp4``. The extension of the file determines the video container format. + codec (str, optional): The codec to use for encoding (e.g., "libx264", + "h264"). If not specified, the default codec + for the container format will be used. + See :ref:`codec_selection` for details. + pixel_format (str, optional): The pixel format for encoding (e.g., + "yuv420p", "yuv444p"). If not specified, uses codec's default format. + Must be left as ``None`` when encoding CUDA tensors. + See :ref:`pixel_format` for details. + crf (int or float, optional): Constant Rate Factor for encoding quality. Lower values + mean better quality. Valid range depends on the encoder (e.g. 0-51 for libx264). + Defaults to None (which will use encoder's default). + See :ref:`crf` for details. + preset (str or int, optional): Encoder option that controls the tradeoff between + encoding encoding speed and compression (output size). Valid on the encoder (commonly + a string: "fast", "medium", "slow"). Defaults to None + (which will use encoder's default). + See :ref:`preset` for details. + extra_options (dict[str, Any], optional): A dictionary of additional + encoder options to pass, e.g. ``{"qp": 5, "tune": "film"}``. 
+ See :ref:`extra_options` for details. """ + preset = str(preset) if isinstance(preset, int) else preset _core.encode_video_to_file( frames=self._frames, frame_rate=self._frame_rate, filename=str(dest), + codec=codec, + pixel_format=pixel_format, + crf=crf, + preset=preset, + extra_options=[ + str(x) for k, v in (extra_options or {}).items() for x in (k, v) + ], ) def to_tensor( self, format: str, + *, + codec: str | None = None, + pixel_format: str | None = None, + crf: int | float | None = None, + preset: str | int | None = None, + extra_options: dict[str, Any] | None = None, ) -> Tensor: """Encode frames into raw bytes, as a 1D uint8 Tensor. Args: format (str): The container format of the encoded frames, e.g. "mp4", "mov", - "mkv", "avi", "webm", "flv", or "gif" + "mkv", "avi", "webm", "flv", etc. + codec (str, optional): The codec to use for encoding (e.g., "libx264", + "h264"). If not specified, the default codec + for the container format will be used. + See :ref:`codec_selection` for details. + pixel_format (str, optional): The pixel format to encode frames into (e.g., + "yuv420p", "yuv444p"). If not specified, uses codec's default format. + Must be left as ``None`` when encoding CUDA tensors. + See :ref:`pixel_format` for details. + crf (int or float, optional): Constant Rate Factor for encoding quality. Lower values + mean better quality. Valid range depends on the encoder (e.g. 0-51 for libx264). + Defaults to None (which will use encoder's default). + See :ref:`crf` for details. + preset (str or int, optional): Encoder option that controls the tradeoff between + encoding encoding speed and compression (output size). Valid on the encoder (commonly + a string: "fast", "medium", "slow"). Defaults to None + (which will use encoder's default). + See :ref:`preset` for details. + extra_options (dict[str, Any], optional): A dictionary of additional + encoder options to pass, e.g. ``{"qp": 5, "tune": "film"}``. + See :ref:`extra_options` for details. 
Returns: - Tensor: The raw encoded bytes as 4D uint8 Tensor. + Tensor: The raw encoded bytes as 1D uint8 Tensor on CPU regardless of the device of the input frames. """ + preset_value = str(preset) if isinstance(preset, int) else preset return _core.encode_video_to_tensor( frames=self._frames, frame_rate=self._frame_rate, format=format, + codec=codec, + pixel_format=pixel_format, + crf=crf, + preset=preset_value, + extra_options=[ + str(x) for k, v in (extra_options or {}).items() for x in (k, v) + ], ) def to_file_like( self, file_like, format: str, + *, + codec: str | None = None, + pixel_format: str | None = None, + crf: int | float | None = None, + preset: str | int | None = None, + extra_options: dict[str, Any] | None = None, ) -> None: """Encode frames into a file-like object. @@ -82,11 +158,39 @@ def to_file_like( ``write(data: bytes) -> int`` and ``seek(offset: int, whence: int = 0) -> int``. format (str): The container format of the encoded frames, e.g. "mp4", "mov", - "mkv", "avi", "webm", "flv", or "gif". + "mkv", "avi", "webm", "flv", etc. + codec (str, optional): The codec to use for encoding (e.g., "libx264", + "h264"). If not specified, the default codec + for the container format will be used. + See :ref:`codec_selection` for details. + pixel_format (str, optional): The pixel format for encoding (e.g., + "yuv420p", "yuv444p"). If not specified, uses codec's default format. + Must be left as ``None`` when encoding CUDA tensors. + See :ref:`pixel_format` for details. + crf (int or float, optional): Constant Rate Factor for encoding quality. Lower values + mean better quality. Valid range depends on the encoder (e.g. 0-51 for libx264). + Defaults to None (which will use encoder's default). + See :ref:`crf` for details. + preset (str or int, optional): Encoder option that controls the tradeoff between + encoding encoding speed and compression (output size). Valid on the encoder (commonly + a string: "fast", "medium", "slow"). 
Defaults to None + (which will use encoder's default). + See :ref:`preset` for details. + extra_options (dict[str, Any], optional): A dictionary of additional + encoder options to pass, e.g. ``{"qp": 5, "tune": "film"}``. + See :ref:`extra_options` for details. """ + preset = str(preset) if isinstance(preset, int) else preset _core.encode_video_to_file_like( frames=self._frames, frame_rate=self._frame_rate, format=format, file_like=file_like, + codec=codec, + pixel_format=pixel_format, + crf=crf, + preset=preset, + extra_options=[ + str(x) for k, v in (extra_options or {}).items() for x in (k, v) + ], ) diff --git a/src/torchcodec/samplers/_common.py b/src/torchcodec/samplers/_common.py index a129a4483..0739c5538 100644 --- a/src/torchcodec/samplers/_common.py +++ b/src/torchcodec/samplers/_common.py @@ -1,8 +1,8 @@ -from typing import Callable, Union +from collections.abc import Callable from torchcodec import FrameBatch -_LIST_OF_INT_OR_FLOAT = Union[list[int], list[float]] +_LIST_OF_INT_OR_FLOAT = list[int] | list[float] def _repeat_last_policy( diff --git a/src/torchcodec/samplers/_index_based.py b/src/torchcodec/samplers/_index_based.py index 2620c171c..5e0df7de7 100644 --- a/src/torchcodec/samplers/_index_based.py +++ b/src/torchcodec/samplers/_index_based.py @@ -1,4 +1,4 @@ -from typing import Literal, Optional +from typing import Literal import torch @@ -125,7 +125,7 @@ def _generic_index_based_sampler( num_frames_per_clip: int, num_indices_between_frames: int, sampling_range_start: int, - sampling_range_end: Optional[int], # interval is [start, end). + sampling_range_end: int | None, # interval is [start, end). # Important note: sampling_range_end defines the upper bound of where a clip # can *start*, not where a clip can end. 
policy: Literal["repeat_last", "wrap", "error"], @@ -192,7 +192,7 @@ def clips_at_random_indices( num_frames_per_clip: int = 1, num_indices_between_frames: int = 1, sampling_range_start: int = 0, - sampling_range_end: Optional[int] = None, # interval is [start, end). + sampling_range_end: int | None = None, # interval is [start, end). policy: Literal["repeat_last", "wrap", "error"] = "repeat_last", ) -> FrameBatch: # See docstring below @@ -216,7 +216,7 @@ def clips_at_regular_indices( num_frames_per_clip: int = 1, num_indices_between_frames: int = 1, sampling_range_start: int = 0, - sampling_range_end: Optional[int] = None, # interval is [start, end). + sampling_range_end: int | None = None, # interval is [start, end). policy: Literal["repeat_last", "wrap", "error"] = "repeat_last", ) -> FrameBatch: # See docstring below diff --git a/src/torchcodec/samplers/_time_based.py b/src/torchcodec/samplers/_time_based.py index d58114121..beb49addd 100644 --- a/src/torchcodec/samplers/_time_based.py +++ b/src/torchcodec/samplers/_time_based.py @@ -1,4 +1,4 @@ -from typing import Literal, Optional +from typing import Literal import torch @@ -151,13 +151,13 @@ def _generic_time_based_sampler( kind: Literal["random", "regular"], decoder, *, - num_clips: Optional[int], # mutually exclusive with seconds_between_clip_starts - seconds_between_clip_starts: Optional[float], + num_clips: int | None, # mutually exclusive with seconds_between_clip_starts + seconds_between_clip_starts: float | None, num_frames_per_clip: int, - seconds_between_frames: Optional[float], + seconds_between_frames: float | None, # None means "begining", which may not always be 0 - sampling_range_start: Optional[float], - sampling_range_end: Optional[float], # interval is [start, end). + sampling_range_start: float | None, + sampling_range_end: float | None, # interval is [start, end). 
policy: Literal["repeat_last", "wrap", "error"] = "repeat_last", ) -> FrameBatch: # Note: *everywhere*, sampling_range_end denotes the upper bound of where a @@ -192,7 +192,8 @@ def _generic_time_based_sampler( # torch.rand() returns in [0, 1) # which ensures all clip starts are < sampling_range_end clip_start_seconds = ( - torch.rand(num_clips) * sampling_range_width + sampling_range_start + torch.rand(num_clips, dtype=torch.float64) * sampling_range_width + + sampling_range_start ) else: assert seconds_between_clip_starts is not None # appease type-checker @@ -200,6 +201,7 @@ def _generic_time_based_sampler( sampling_range_start, sampling_range_end, # excluded seconds_between_clip_starts, + dtype=torch.float64, ) # As mentioned in the docs, torch.arange may return values # equal to or above `end` because of floating precision errors. @@ -232,10 +234,10 @@ def clips_at_random_timestamps( *, num_clips: int = 1, num_frames_per_clip: int = 1, - seconds_between_frames: Optional[float] = None, + seconds_between_frames: float | None = None, # None means "begining", which may not always be 0 - sampling_range_start: Optional[float] = None, - sampling_range_end: Optional[float] = None, # interval is [start, end). + sampling_range_start: float | None = None, + sampling_range_end: float | None = None, # interval is [start, end). policy: Literal["repeat_last", "wrap", "error"] = "repeat_last", ) -> FrameBatch: # See docstring below @@ -258,10 +260,10 @@ def clips_at_regular_timestamps( *, seconds_between_clip_starts: float, num_frames_per_clip: int = 1, - seconds_between_frames: Optional[float] = None, + seconds_between_frames: float | None = None, # None means "begining", which may not always be 0 - sampling_range_start: Optional[float] = None, - sampling_range_end: Optional[float] = None, # interval is [start, end). + sampling_range_start: float | None = None, + sampling_range_end: float | None = None, # interval is [start, end). 
policy: Literal["repeat_last", "wrap", "error"] = "repeat_last", ) -> FrameBatch: # See docstring below diff --git a/src/torchcodec/share/cmake/TorchCodec/TorchCodecConfig.cmake b/src/torchcodec/share/cmake/TorchCodec/TorchCodecConfig.cmake new file mode 100644 index 000000000..e199faa69 --- /dev/null +++ b/src/torchcodec/share/cmake/TorchCodec/TorchCodecConfig.cmake @@ -0,0 +1,76 @@ +# FindTorchCodec +# -------------- +# +# Finds the TorchCodec library +# +# This will define the following variables: +# +# TORCHCODEC_FOUND: True if the system has the TorchCodec library +# TORCHCODEC_VARIANTS: list of TorchCodec variants. A variant is a supported +# FFmpeg major version. +# +# and the following imported targets: +# +# torchcodec::ffmpeg${N} +# torchcodec::core${N} +# +# where N is a TorchCodec variant (FFmpeg major version) from +# TORCHCODEC_VARIANTS list. + +include(FindPackageHandleStandardArgs) +include("${CMAKE_CURRENT_LIST_DIR}/ffmpeg_versions.cmake") + +# Assume we are in /share/cmake/TorchCodec/TorchCodecConfig.cmake +get_filename_component(CMAKE_CURRENT_LIST_DIR "${CMAKE_CURRENT_LIST_FILE}" PATH) +get_filename_component(TORCHCODEC_INSTALL_PREFIX "${CMAKE_CURRENT_LIST_DIR}/../../../" ABSOLUTE) + +# Include directories. 
+set(TORCHCODEC_INCLUDE_DIRS ${TORCHCODEC_INSTALL_PREFIX}/_core) +set(TORCHCODEC_VARIANTS "") + +function(add_torchcodec_target ffmpeg_major_version) + set(target torchcodec::core${ffmpeg_major_version}) + + if (NOT TARGET torchcodec::ffmpeg${ffmpeg_major_version}) + message(FATAL_ERROR "torchcodec::ffmpeg${ffmpeg_major_version} target is not defined") + endif() + + find_library(lib_path torchcodec_core${ffmpeg_major_version} + PATHS "${TORCHCODEC_INSTALL_PREFIX}" NO_CACHE NO_DEFAULT_PATH) + if (NOT lib_path) + message(FATAL_ERROR "torchcodec_core${ffmpeg_major_version} shared library is missing") + endif() + + message("Adding ${target} target") + add_library(${target} SHARED IMPORTED) + add_dependencies(${target} torchcodec::ffmpeg${ffmpeg_major_version}) + set_target_properties(${target} PROPERTIES + INTERFACE_INCLUDE_DIRECTORIES ${TORCHCODEC_INCLUDE_DIRS} + IMPORTED_LOCATION ${lib_path} + ) + + list(APPEND TORCHCODEC_VARIANTS "${ffmpeg_major_version}") + set(TORCHCODEC_VARIANTS "${TORCHCODEC_VARIANTS}" PARENT_SCOPE) +endfunction() + +# If any of the TORCHCODEC_FFMPEG${N}_INSTALL_PREFIX environment variables +# are defined, use them to locate the corresponding FFmpeg and TorchCodec targets. +# Otherwise, fall back to pkg-config to find FFmpeg. 
+set(use_pkg_config TRUE) +foreach(ffmpeg_major_version IN LISTS TORCHCODEC_SUPPORTED_FFMPEG_VERSIONS) + if (DEFINED ENV{TORCHCODEC_FFMPEG${ffmpeg_major_version}_INSTALL_PREFIX}) + add_ffmpeg_target( + "${ffmpeg_major_version}" + "$ENV{TORCHCODEC_FFMPEG${ffmpeg_major_version}_INSTALL_PREFIX}" + ) + add_torchcodec_target(${ffmpeg_major_version}) + set(use_pkg_config FALSE) + endif() +endforeach() + +if (use_pkg_config) + add_ffmpeg_target_with_pkg_config(ffmpeg_major_version) + add_torchcodec_target(${ffmpeg_major_version}) +endif() + +find_package_handle_standard_args(TorchCodec DEFAULT_MSG TORCHCODEC_VARIANTS) diff --git a/src/torchcodec/share/cmake/TorchCodec/ffmpeg_versions.cmake b/src/torchcodec/share/cmake/TorchCodec/ffmpeg_versions.cmake new file mode 100644 index 000000000..5f4ea87c9 --- /dev/null +++ b/src/torchcodec/share/cmake/TorchCodec/ffmpeg_versions.cmake @@ -0,0 +1,122 @@ +# This file exposes helpers to create and expose FFmpeg targets as torchcodec::ffmpeg${N} +# where N is the FFmpeg major version. + +# List of FFmpeg versions that TorchCodec can support - that's not a list of +# FFmpeg versions available on the current system! +set(TORCHCODEC_SUPPORTED_FFMPEG_VERSIONS "4;5;6;7;8") + +# Create and expose torchcodec::ffmpeg${ffmpeg_major_version} target which can +# then be used as a dependency in other targets. +# prefix is the path to the FFmpeg installation containing the usual `include` +# and `lib` directories. +function(add_ffmpeg_target ffmpeg_major_version prefix) + # Check that given ffmpeg major version is something we support and error out if + # it's not. 
+ list(FIND TORCHCODEC_SUPPORTED_FFMPEG_VERSIONS "${ffmpeg_major_version}" _index) + if (_index LESS 0) + message(FATAL_ERROR "FFmpeg version ${ffmpeg_major_version} is not supported") + endif() + if (NOT DEFINED prefix) + message(FATAL_ERROR "No prefix defined calling add_ffmpeg_target()") + endif() + + # Define library names based on platform and FFmpeg version + if (LINUX) + if (ffmpeg_major_version EQUAL 4) + set(library_file_names libavutil.so.56 libavcodec.so.58 libavformat.so.58 libavdevice.so.58 libavfilter.so.7 libswscale.so.5 libswresample.so.3) + elseif (ffmpeg_major_version EQUAL 5) + set(library_file_names libavutil.so.57 libavcodec.so.59 libavformat.so.59 libavdevice.so.59 libavfilter.so.8 libswscale.so.6 libswresample.so.4) + elseif (ffmpeg_major_version EQUAL 6) + set(library_file_names libavutil.so.58 libavcodec.so.60 libavformat.so.60 libavdevice.so.60 libavfilter.so.9 libswscale.so.7 libswresample.so.4) + elseif (ffmpeg_major_version EQUAL 7) + set(library_file_names libavutil.so.59 libavcodec.so.61 libavformat.so.61 libavdevice.so.61 libavfilter.so.10 libswscale.so.8 libswresample.so.5) + elseif (ffmpeg_major_version EQUAL 8) + set(library_file_names libavutil.so.60 libavcodec.so.62 libavformat.so.62 libavdevice.so.62 libavfilter.so.11 libswscale.so.9 libswresample.so.6) + endif() + elseif (APPLE) + if (ffmpeg_major_version EQUAL 4) + set(library_file_names libavutil.56.dylib libavcodec.58.dylib libavformat.58.dylib libavdevice.58.dylib libavfilter.7.dylib libswscale.5.dylib libswresample.3.dylib) + elseif (ffmpeg_major_version EQUAL 5) + set(library_file_names libavutil.57.dylib libavcodec.59.dylib libavformat.59.dylib libavdevice.59.dylib libavfilter.8.dylib libswscale.6.dylib libswresample.4.dylib) + elseif (ffmpeg_major_version EQUAL 6) + set(library_file_names libavutil.58.dylib libavcodec.60.dylib libavformat.60.dylib libavdevice.60.dylib libavfilter.9.dylib libswscale.7.dylib libswresample.4.dylib) + elseif (ffmpeg_major_version EQUAL 7) 
+ set(library_file_names libavutil.59.dylib libavcodec.61.dylib libavformat.61.dylib libavdevice.61.dylib libavfilter.10.dylib libswscale.8.dylib libswresample.5.dylib) + elseif (ffmpeg_major_version EQUAL 8) + set(library_file_names libavutil.60.dylib libavcodec.62.dylib libavformat.62.dylib libavdevice.62.dylib libavfilter.11.dylib libswscale.9.dylib libswresample.6.dylib) + endif() + elseif (WIN32) + set(library_file_names avutil.lib avcodec.lib avformat.lib avdevice.lib avfilter.lib swscale.lib swresample.lib) + else() + message(FATAL_ERROR "Unsupported operating system: ${CMAKE_SYSTEM_NAME}") + endif() + + set(target "torchcodec::ffmpeg${ffmpeg_major_version}") + set(include_dir "${prefix}/include") + if (LINUX OR APPLE) + set(lib_dir "${prefix}/lib") + elseif (WIN32) + set(lib_dir "${prefix}/bin") + else() + message(FATAL_ERROR "Unsupported operating system: ${CMAKE_SYSTEM_NAME}") + endif() + + list( + TRANSFORM library_file_names + PREPEND ${lib_dir}/ + OUTPUT_VARIABLE lib_paths + ) + + message("Adding ${target} target") + # Verify that ffmpeg includes and libraries actually exist. + foreach (path IN LISTS include_dir lib_paths) + if (NOT EXISTS "${path}") + message(FATAL_ERROR "${path} does not exist") + endif() + endforeach() + + # Actually define the target + add_library(${target} INTERFACE IMPORTED) + target_include_directories(${target} INTERFACE ${include_dir}) + target_link_libraries(${target} INTERFACE ${lib_paths}) +endfunction() + +# Create and expose torchcodec::ffmpeg${ffmpeg_major_version} target which can +# then be used as a dependency in other targets. +# The FFmpeg installation is found by pkg-config. +function(add_ffmpeg_target_with_pkg_config ret_ffmpeg_major_version_var) + find_package(PkgConfig REQUIRED) + pkg_check_modules(TORCHCODEC_LIBAV REQUIRED IMPORTED_TARGET + libavdevice + libavfilter + libavformat + libavcodec + libavutil + libswresample + libswscale + ) + + # Split libavcodec's version string by '.' 
and convert it to a list + # The TORCHCODEC_LIBAV_libavcodec_VERSION is made available by pkg-config. + string(REPLACE "." ";" libavcodec_version_list ${TORCHCODEC_LIBAV_libavcodec_VERSION}) + # Get the first element of the list, which is the major version + list(GET libavcodec_version_list 0 libavcodec_major_version) + + if (${libavcodec_major_version} STREQUAL "58") + set(ffmpeg_major_version "4") + elseif (${libavcodec_major_version} STREQUAL "59") + set(ffmpeg_major_version "5") + elseif (${libavcodec_major_version} STREQUAL "60") + set(ffmpeg_major_version "6") + elseif (${libavcodec_major_version} STREQUAL "61") + set(ffmpeg_major_version "7") + elseif (${libavcodec_major_version} STREQUAL "62") + set(ffmpeg_major_version "8") + else() + message(FATAL_ERROR "Unsupported libavcodec version: ${libavcodec_major_version}") + endif() + + message("Adding torchcodec::ffmpeg${ffmpeg_major_version} target") + add_library(torchcodec::ffmpeg${ffmpeg_major_version} ALIAS PkgConfig::TORCHCODEC_LIBAV) + set(${ret_ffmpeg_major_version_var} ${ffmpeg_major_version} PARENT_SCOPE) +endfunction() diff --git a/src/torchcodec/transforms/__init__.py b/src/torchcodec/transforms/__init__.py new file mode 100644 index 000000000..4720e4ef4 --- /dev/null +++ b/src/torchcodec/transforms/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from ._decoder_transforms import ( # noqa + CenterCrop, + DecoderTransform, + RandomCrop, + Resize, +) diff --git a/src/torchcodec/transforms/_decoder_transforms.py b/src/torchcodec/transforms/_decoder_transforms.py new file mode 100644 index 000000000..378e5235d --- /dev/null +++ b/src/torchcodec/transforms/_decoder_transforms.py @@ -0,0 +1,375 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +from abc import ABC, abstractmethod +from collections.abc import Sequence +from types import ModuleType + +import torch +from torch import nn + + +class DecoderTransform(ABC): + """Base class for all decoder transforms. + + A *decoder transform* is a transform that is applied by the decoder before + returning the decoded frame. Applying decoder transforms to frames + should be both faster and more memory efficient than receiving normally + decoded frames and applying the same kind of transform. + + Most ``DecoderTransform`` objects have a complementary transform in TorchVision, + specifically in `torchvision.transforms.v2 <https://docs.pytorch.org/vision/stable/transforms.html>`_. + For such transforms, we ensure that: + + 1. The names are the same. + 2. Default behaviors are the same. + 3. The parameters for the ``DecoderTransform`` object are a subset of the + TorchVision :class:`~torchvision.transforms.v2.Transform` object. + 4. Parameters with the same name control the same behavior and accept a + subset of the same types. + 5. The difference between the frames returned by a decoder transform and + the complementary TorchVision transform are such that a model should + not be able to tell the difference. + """ + + @abstractmethod + def _make_transform_spec(self, input_dims: tuple[int | None, int | None]) -> str: + """Makes the transform spec that is used by the `VideoDecoder`. + + Args: + input_dims (tuple[int | None, int | None]): The dimensions of + the input frame in the form (height, width). We cannot know the + dimensions at object construction time because it's dependent on + the video being decoded and upstream transforms in the same + transform pipeline. Not all transforms need to know this; those + that don't will ignore it. The individual values in the tuple are + optional because the original values come from file metadata which + may be missing.
We maintain the optionality throughout the APIs so + that we can decide as late as possible that it's necessary for the + values to exist. That is, if the values are missing from the + metadata and we have transforms which ignore the input dimensions, + we want that to still work. + + Note: This method is the moral equivalent of TorchVision's + `Transform.make_params()`. + + Returns: + str: A string which contains the spec for the transform that the + `VideoDecoder` knows what to do with. + """ + pass + + def _get_output_dims(self) -> tuple[int | None, int | None] | None: + """Get the dimensions of the output frame. + + Transforms that change the frame dimensions need to override this + method. Transforms that don't change the frame dimensions can rely on + this default implementation. + + Returns: + tuple[int | None, int | None] | None: The output dimensions. + - None: The output dimensions are the same as the input dimensions. + - (int, int): The (height, width) of the output frame. + """ + return None + + +def import_torchvision_transforms_v2() -> ModuleType: + try: + from torchvision.transforms import v2 + except ImportError as e: + raise RuntimeError( + "Cannot import TorchVision; this should never happen, please report a bug." + ) from e + return v2 + + +class Resize(DecoderTransform): + """Resize the decoded frame to a given size. + + Complementary TorchVision transform: :class:`~torchvision.transforms.v2.Resize`. + Interpolation is always bilinear. Anti-aliasing is always on. + + Args: + size (Sequence[int]): Desired output size. Must be a sequence of + the form (height, width). + """ + + def __init__(self, size: Sequence[int]): + if len(size) != 2: + raise ValueError( + "Resize transform must have a (height, width) " + f"pair for the size, got {size}." 
+ ) + self.size = size + + def _make_transform_spec(self, input_dims: tuple[int | None, int | None]) -> str: + return f"resize, {self.size[0]}, {self.size[1]}" + + def _get_output_dims(self) -> tuple[int | None, int | None] | None: + return (self.size[0], self.size[1]) + + @classmethod + def _from_torchvision(cls, tv_resize: nn.Module): + v2 = import_torchvision_transforms_v2() + + assert isinstance(tv_resize, v2.Resize) + + if tv_resize.interpolation is not v2.InterpolationMode.BILINEAR: + raise ValueError( + "TorchVision Resize transform must use bilinear interpolation." + ) + if tv_resize.antialias is False: + raise ValueError( + "TorchVision Resize transform must have antialias enabled." + ) + if tv_resize.size is None: + raise ValueError("TorchVision Resize transform must have a size specified.") + if len(tv_resize.size) != 2: + raise ValueError( + "TorchVision Resize transform must have a (height, width) " + f"pair for the size, got {tv_resize.size}." + ) + return cls(size=tv_resize.size) + + +class CenterCrop(DecoderTransform): + """Crop the decoded frame to a given size in the center of the frame. + + Complementary TorchVision transform: :class:`~torchvision.transforms.v2.CenterCrop`. + + Args: + size (Sequence[int]): Desired output size. Must be a sequence of + the form (height, width). + """ + + def __init__(self, size: Sequence[int]): + if len(size) != 2: + raise ValueError( + "CenterCrop transform must have a (height, width) " + f"pair for the size, got {size}." 
+ ) + self.size = size + + def _make_transform_spec(self, input_dims: tuple[int | None, int | None]) -> str: + return f"center_crop, {self.size[0]}, {self.size[1]}" + + def _get_output_dims(self) -> tuple[int | None, int | None] | None: + return (self.size[0], self.size[1]) + + @classmethod + def _from_torchvision( + cls, + tv_center_crop: nn.Module, + ): + v2 = import_torchvision_transforms_v2() + + if not isinstance(tv_center_crop, v2.CenterCrop): + raise ValueError( + "Transform must be TorchVision's CenterCrop, " + f"it is instead {type(tv_center_crop).__name__}. " + "This should never happen, please report a bug." + ) + + if len(tv_center_crop.size) != 2: + raise ValueError( + "TorchVision CenterCrop transform must have a (height, width) " + f"pair for the size, got {tv_center_crop.size}." + ) + + return cls(size=tv_center_crop.size) + + +class RandomCrop(DecoderTransform): + """Crop the decoded frame to a given size at a random location in the frame. + + Complementary TorchVision transform: :class:`~torchvision.transforms.v2.RandomCrop`. + Padding of all kinds is disabled. The random location within the frame is + determined during the initialization of the + :class:`~torchcodec.decoders.VideoDecoder` object that owns this transform. + As a consequence, each decoded frame in the video will be cropped at the + same location. Videos with variable resolution may result in undefined + behavior. + + Args: + size (Sequence[int]): Desired output size. Must be a sequence of + the form (height, width). + """ + + def __init__(self, size: Sequence[int]): + if len(size) != 2: + raise ValueError( + "RandomCrop transform must have a (height, width) " + f"pair for the size, got {size}." + ) + self.size = size + + def _make_transform_spec(self, input_dims: tuple[int | None, int | None]) -> str: + height, width = input_dims + if height is None: + raise ValueError( + "Video metadata has no height. " + "RandomCrop can only be used when input frame dimensions are known." 
+ ) + if width is None: + raise ValueError( + "Video metadata has no width. " + "RandomCrop can only be used when input frame dimensions are known." + ) + + # Note: This logic below must match the logic in + # torchvision.transforms.v2.RandomCrop.make_params(). Given + # the same seed, they should get the same result. This is an + # API guarantee with our users. + if height < self.size[0] or width < self.size[1]: + raise ValueError( + f"Input dimensions {input_dims} are smaller than the crop size {self.size}." + ) + + top = int(torch.randint(0, height - self.size[0] + 1, size=()).item()) + left = int(torch.randint(0, width - self.size[1] + 1, size=()).item()) + + return f"crop, {self.size[0]}, {self.size[1]}, {left}, {top}" + + def _get_output_dims(self) -> tuple[int | None, int | None] | None: + return (self.size[0], self.size[1]) + + @classmethod + def _from_torchvision( + cls, + tv_random_crop: nn.Module, + ): + v2 = import_torchvision_transforms_v2() + + if not isinstance(tv_random_crop, v2.RandomCrop): + raise ValueError( + "Transform must be TorchVision's RandomCrop, " + f"it is instead {type(tv_random_crop).__name__}. " + "This should never happen, please report a bug." + ) + + if tv_random_crop.padding is not None: + raise ValueError( + "TorchVision RandomCrop transform must not specify padding." + ) + + if tv_random_crop.pad_if_needed is True: + raise ValueError( + "TorchVision RandomCrop transform must not specify pad_if_needed." + ) + + if tv_random_crop.fill != 0: + raise ValueError("TorchVision RandomCrop fill must be 0.") + + if tv_random_crop.padding_mode != "constant": + raise ValueError("TorchVision RandomCrop padding_mode must be constant.") + + if len(tv_random_crop.size) != 2: + raise ValueError( + "TorchVision RandomCrop transform must have a (height, width) " + f"pair for the size, got {tv_random_crop.size}."
+ ) + + return cls(size=tv_random_crop.size) + + +def _make_transform_specs( + transforms: Sequence[DecoderTransform | nn.Module] | None, + input_dims: tuple[int | None, int | None], +) -> str: + """Given a sequence of transforms, turn those into the specification string + the core API expects. + + Args: + transforms: Optional sequence of transform objects. The objects can be + one of two types: + 1. torchcodec.transforms.DecoderTransform + 2. torchvision.transforms.v2.Transform, but our type annotation + only mentions its base, nn.Module. We don't want to take a + hard dependency on TorchVision. + input_dims: Optional (height, width) pair. Note that only some + transforms need to know the dimensions. If the user provides + transforms that don't need to know the dimensions, and that metadata + is missing, everything should still work. That means we assert their + existence as late as possible. + + Returns: + String of transforms in the format the core API expects: transform + specifications separated by semicolons. + """ + if transforms is None: + return "" + + try: + from torchvision.transforms import v2 + + tv_available = True + except ImportError: + tv_available = False + + # The following loop accomplishes two tasks: + # + # 1. Converts the transform to a DecoderTransform, if necessary. We + # accept TorchVision transform objects and they must be converted + # to their matching DecoderTransform. + # 2. Calculates what the input dimensions are to each transform. + # + # The order in our transforms list is semantically meaningful, as we + # actually have a pipeline where the output of one transform is the input to + # the next. For example, if we have the transforms list [A, B, C, D], then + # we should understand that as: + # + # A -> B -> C -> D + # + # Where the frame produced by A is the input to B, the frame produced by B + # is the input to C, etc. This particularly matters for frame dimensions. + # Transforms can both: + # + # 1.
Produce frames with arbitrary dimensions. + # 2. Rely on their input frame's dimensions to calculate ahead-of-time + # what their runtime behavior will be. + # + # The consequence of the above facts is that we need to statically track + # frame dimensions in the pipeline while we pre-process it. The input + # frame's dimensions to A, our first transform, is always what we know from + # our metadata. For each transform, we always calculate its output + # dimensions from its input dimensions. We store these with the converted + # transform, to be all used together when we generate the specs. + converted_transforms: list[ + tuple[ + DecoderTransform, + # A (height, width) pair where the values may be missing. + tuple[int | None, int | None], + ] + ] = [] + curr_input_dims = input_dims + for transform in transforms: + if not isinstance(transform, DecoderTransform): + if not tv_available: + raise ValueError( + f"The supplied transform, {transform}, is not a TorchCodec " + "DecoderTransform. TorchCodec also accepts TorchVision " + "v2 transforms, but TorchVision is not installed." + ) + elif isinstance(transform, v2.Resize): + transform = Resize._from_torchvision(transform) + elif isinstance(transform, v2.CenterCrop): + transform = CenterCrop._from_torchvision(transform) + elif isinstance(transform, v2.RandomCrop): + transform = RandomCrop._from_torchvision(transform) + else: + raise ValueError( + f"Unsupported transform: {transform}. Transforms must be " + "either a TorchCodec DecoderTransform or a TorchVision " + "v2 transform."
+ ) + + converted_transforms.append((transform, curr_input_dims)) + output_dims = transform._get_output_dims() + curr_input_dims = output_dims if output_dims is not None else curr_input_dims + + return ";".join([t._make_transform_spec(dims) for t, dims in converted_transforms]) diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 988def933..b9d755a2a 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -1,7 +1,7 @@ cmake_minimum_required(VERSION 3.14) include(CMakePrintHelpers) project(TorchCodecTests) -set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD 20) set(CMAKE_CXX_STANDARD_REQUIRED) find_package(Torch REQUIRED) @@ -34,3 +34,22 @@ target_link_libraries( include(GoogleTest) gtest_discover_tests(VideoDecoderTest) + + +add_executable( + MetadataTest + MetadataTest.cpp +) + +target_include_directories(MetadataTest SYSTEM PRIVATE ${TORCH_INCLUDE_DIRS}) +target_include_directories(MetadataTest SYSTEM PRIVATE ${libav_include_dirs}) +target_include_directories(MetadataTest PRIVATE ../) + +target_link_libraries( + MetadataTest + ${libtorchcodec_library_name} + ${libtorchcodec_custom_ops_name} + GTest::gtest_main +) + +gtest_discover_tests(MetadataTest) diff --git a/test/MetadataTest.cpp b/test/MetadataTest.cpp new file mode 100644 index 000000000..3a65b6dc5 --- /dev/null +++ b/test/MetadataTest.cpp @@ -0,0 +1,191 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
+ +#include "src/torchcodec/_core/Metadata.h" + +#include <gtest/gtest.h> + +namespace facebook::torchcodec { + +// Test that num_frames_from_content always has priority when accessing +// getNumFrames() +TEST(MetadataTest, NumFramesFallbackPriority) { + // in exact mode, both header and content available + { + StreamMetadata metadata; + metadata.numFramesFromHeader = 10; + metadata.numFramesFromContent = 20; + metadata.durationSecondsFromHeader = 4.0; + metadata.averageFpsFromHeader = 30.0; + + EXPECT_EQ(metadata.getNumFrames(SeekMode::exact), 20); + } + + // in exact mode, only content available + { + StreamMetadata metadata; + metadata.numFramesFromHeader = std::nullopt; + metadata.numFramesFromContent = 10; + metadata.durationSecondsFromHeader = 4.0; + metadata.averageFpsFromHeader = 30.0; + + EXPECT_EQ(metadata.getNumFrames(SeekMode::exact), 10); + } + + // in approximate mode, header should be used + { + StreamMetadata metadata; + metadata.numFramesFromHeader = 10; + metadata.numFramesFromContent = std::nullopt; + metadata.durationSecondsFromHeader = 4.0; + metadata.averageFpsFromHeader = 30.0; + + EXPECT_EQ(metadata.getNumFrames(SeekMode::approximate), 10); + } +} + +// Test that if num_frames_from_content and num_frames_from_header are missing, +// getNumFrames() is calculated using average_fps_from_header and +// duration_seconds_from_header in approximate mode +TEST(MetadataTest, CalculateNumFramesUsingFpsAndDuration) { + // both fps and duration available + { + StreamMetadata metadata; + metadata.numFramesFromHeader = std::nullopt; + metadata.numFramesFromContent = std::nullopt; + metadata.averageFpsFromHeader = 60.0; + metadata.durationSecondsFromHeader = 10.0; + + EXPECT_EQ(metadata.getNumFrames(SeekMode::approximate), 600); + } + + // fps available but duration missing + { + StreamMetadata metadata; + metadata.numFramesFromHeader = std::nullopt; + metadata.numFramesFromContent = std::nullopt; + metadata.averageFpsFromHeader = 60.0; +
metadata.durationSecondsFromHeader = std::nullopt; + + EXPECT_EQ(metadata.getNumFrames(SeekMode::approximate), std::nullopt); + } + + // duration available but fps missing + { + StreamMetadata metadata; + metadata.numFramesFromHeader = std::nullopt; + metadata.numFramesFromContent = std::nullopt; + metadata.averageFpsFromHeader = std::nullopt; + metadata.durationSecondsFromHeader = 10.0; + + EXPECT_EQ(metadata.getNumFrames(SeekMode::approximate), std::nullopt); + } + + // both missing + { + StreamMetadata metadata; + metadata.numFramesFromHeader = std::nullopt; + metadata.numFramesFromContent = std::nullopt; + metadata.averageFpsFromHeader = std::nullopt; + metadata.durationSecondsFromHeader = std::nullopt; + + EXPECT_EQ(metadata.getNumFrames(SeekMode::approximate), std::nullopt); + } +} + +// Test that using begin_stream_seconds_from_content and +// end_stream_seconds_from_content to calculate getDurationSeconds() has +// priority. If either value is missing, duration_seconds_from_header is used. 
+TEST(MetadataTest, DurationSecondsFallback) { + // in exact mode, both begin and end content available, should calculate from + // them + { + StreamMetadata metadata; + metadata.durationSecondsFromHeader = 60.0; + metadata.beginStreamPtsSecondsFromContent = 5.0; + metadata.endStreamPtsSecondsFromContent = 20.0; + + EXPECT_NEAR( + metadata.getDurationSeconds(SeekMode::exact).value(), 15.0, 1e-6); + } + + // in exact mode, only content values, no header + { + StreamMetadata metadata; + metadata.durationSecondsFromHeader = std::nullopt; + metadata.beginStreamPtsSecondsFromContent = 0.0; + metadata.endStreamPtsSecondsFromContent = 10.0; + + EXPECT_NEAR( + metadata.getDurationSeconds(SeekMode::exact).value(), 10.0, 1e-6); + } + + // in approximate mode, header value takes priority (ignores content) + { + StreamMetadata metadata; + metadata.durationSecondsFromHeader = 60.0; + metadata.beginStreamPtsSecondsFromContent = 5.0; + metadata.endStreamPtsSecondsFromContent = 20.0; + + EXPECT_NEAR( + metadata.getDurationSeconds(SeekMode::approximate).value(), 60.0, 1e-6); + } +} + +// Test that duration_seconds is calculated using average_fps_from_header and +// num_frames_from_header if duration_seconds_from_header is missing. 
+TEST(MetadataTest, CalculateDurationSecondsUsingFpsAndNumFrames) { + // in approximate mode, both num_frames and fps available + { + StreamMetadata metadata; + metadata.durationSecondsFromHeader = std::nullopt; + metadata.numFramesFromHeader = 100; + metadata.averageFpsFromHeader = 10.0; + metadata.beginStreamPtsSecondsFromContent = std::nullopt; + metadata.endStreamPtsSecondsFromContent = std::nullopt; + + EXPECT_NEAR( + metadata.getDurationSeconds(SeekMode::approximate).value(), 10.0, 1e-6); + } + + // in approximate mode, num_frames available but fps missing + { + StreamMetadata metadata; + metadata.durationSecondsFromHeader = std::nullopt; + metadata.numFramesFromHeader = 100; + metadata.averageFpsFromHeader = std::nullopt; + metadata.beginStreamPtsSecondsFromContent = std::nullopt; + metadata.endStreamPtsSecondsFromContent = std::nullopt; + + EXPECT_EQ(metadata.getDurationSeconds(SeekMode::approximate), std::nullopt); + } + + // in approximate mode, fps available but num_frames missing + { + StreamMetadata metadata; + metadata.durationSecondsFromHeader = std::nullopt; + metadata.numFramesFromHeader = std::nullopt; + metadata.averageFpsFromHeader = 10.0; + metadata.beginStreamPtsSecondsFromContent = std::nullopt; + metadata.endStreamPtsSecondsFromContent = std::nullopt; + + EXPECT_EQ(metadata.getDurationSeconds(SeekMode::approximate), std::nullopt); + } + + // in approximate mode, both missing + { + StreamMetadata metadata; + metadata.durationSecondsFromHeader = std::nullopt; + metadata.numFramesFromHeader = std::nullopt; + metadata.averageFpsFromHeader = std::nullopt; + metadata.beginStreamPtsSecondsFromContent = std::nullopt; + metadata.endStreamPtsSecondsFromContent = std::nullopt; + + EXPECT_EQ(metadata.getDurationSeconds(SeekMode::approximate), std::nullopt); + } +} + +} // namespace facebook::torchcodec diff --git a/test/VideoDecoderTest.cpp b/test/VideoDecoderTest.cpp index 1481d3a2a..346679f75 100644 --- a/test/VideoDecoderTest.cpp +++ 
b/test/VideoDecoderTest.cpp @@ -27,6 +27,15 @@ C10_DEFINE_bool( namespace facebook::torchcodec { +inline torch::stable::Tensor toStableTensor(const torch::Tensor& tensor) { + torch::Tensor* p = new torch::Tensor(tensor); + return torch::stable::Tensor(reinterpret_cast(p)); +} + +inline torch::Tensor toATenTensor(const torch::stable::Tensor& t) { + return *reinterpret_cast(t.get()); +} + std::string getResourcePath(const std::string& filename) { #ifdef FBCODE_BUILD std::string filepath = "pytorch/torchcodec/test/resources/" + filename; @@ -59,15 +68,16 @@ class SingleStreamDecoderTest : public testing::TestWithParam { char* data = new char[length]; std::memcpy(data, content_.data(), length); auto deleter = [data](void*) { delete[] data; }; - at::Tensor tensor = at::from_blob( + torch::Tensor tensor = torch::from_blob( static_cast(data), {length}, deleter, {torch::kUInt8}); - auto contextHolder = std::make_unique(tensor); + auto contextHolder = + std::make_unique(toStableTensor(tensor)); return std::make_unique( - std::move(contextHolder), SingleStreamDecoder::SeekMode::approximate); + std::move(contextHolder), SeekMode::approximate); } else { return std::make_unique( - filepath, SingleStreamDecoder::SeekMode::approximate); + filepath, SeekMode::approximate); } } @@ -106,7 +116,8 @@ TEST_P(SingleStreamDecoderTest, ReturnsFpsAndDurationForVideoInMetadata) { } TEST(SingleStreamDecoderTest, MissingVideoFileThrowsException) { - EXPECT_THROW(SingleStreamDecoder("/this/file/does/not/exist"), c10::Error); + EXPECT_THROW( + SingleStreamDecoder("/this/file/does/not/exist"), std::runtime_error); } void dumpTensorToDisk( @@ -154,7 +165,7 @@ TEST(SingleStreamDecoderTest, RespectsOutputTensorDimensionOrderFromOptions) { videoStreamOptions.dimensionOrder = "NHWC"; std::vector transforms; decoder->addVideoStream(-1, transforms, videoStreamOptions); - torch::Tensor tensor = decoder->getNextFrame().data; + auto tensor = toATenTensor(decoder->getNextFrame().data); 
EXPECT_EQ(tensor.sizes(), std::vector({270, 480, 3})); } @@ -165,11 +176,11 @@ TEST_P(SingleStreamDecoderTest, ReturnsFirstTwoFramesOfVideo) { std::vector transforms; ourDecoder->addVideoStream(-1, transforms); auto output = ourDecoder->getNextFrame(); - torch::Tensor tensor0FromOurDecoder = output.data; + torch::Tensor tensor0FromOurDecoder = toATenTensor(output.data); EXPECT_EQ(tensor0FromOurDecoder.sizes(), std::vector({3, 270, 480})); EXPECT_EQ(output.ptsSeconds, 0.0); output = ourDecoder->getNextFrame(); - torch::Tensor tensor1FromOurDecoder = output.data; + torch::Tensor tensor1FromOurDecoder = toATenTensor(output.data); EXPECT_EQ(tensor1FromOurDecoder.sizes(), std::vector({3, 270, 480})); EXPECT_EQ(output.ptsSeconds, 1'001. / 30'000); @@ -188,8 +199,6 @@ TEST_P(SingleStreamDecoderTest, ReturnsFirstTwoFramesOfVideo) { torch::allclose(tensor1FromOurDecoder, tensor1FromFFMPEG, 0.1, 20)); if (FLAGS_dump_frames_for_debugging) { - dumpTensorToDisk(tensor0FromFFMPEG, "tensor0FromFFMPEG.pt"); - dumpTensorToDisk(tensor1FromFFMPEG, "tensor1FromFFMPEG.pt"); dumpTensorToDisk(tensor0FromOurDecoder, "tensor0FromOurDecoder.pt"); dumpTensorToDisk(tensor1FromOurDecoder, "tensor1FromOurDecoder.pt"); } @@ -205,9 +214,9 @@ TEST_P(SingleStreamDecoderTest, DecodesFramesInABatchInNCHW) { std::vector transforms; ourDecoder->addVideoStream(bestVideoStreamIndex, transforms); // Frame with index 180 corresponds to timestamp 6.006. - auto frameIndices = torch::tensor({0, 180}); + auto frameIndices = toStableTensor(torch::tensor({0, 180})); auto output = ourDecoder->getFramesAtIndices(frameIndices); - auto tensor = output.data; + auto tensor = toATenTensor(output.data); EXPECT_EQ(tensor.sizes(), std::vector({2, 3, 270, 480})); torch::Tensor tensor0FromFFMPEG = @@ -232,9 +241,9 @@ TEST_P(SingleStreamDecoderTest, DecodesFramesInABatchInNHWC) { ourDecoder->addVideoStream( bestVideoStreamIndex, transforms, videoStreamOptions); // Frame with index 180 corresponds to timestamp 6.006. 
- auto frameIndices = torch::tensor({0, 180}); + auto frameIndices = toStableTensor(torch::tensor({0, 180})); auto output = ourDecoder->getFramesAtIndices(frameIndices); - auto tensor = output.data; + auto tensor = toATenTensor(output.data); EXPECT_EQ(tensor.sizes(), std::vector({2, 270, 480, 3})); torch::Tensor tensor0FromFFMPEG = @@ -300,15 +309,16 @@ TEST_P(SingleStreamDecoderTest, SeeksToFrameWithSpecificPts) { ourDecoder->addVideoStream(-1, transforms); ourDecoder->setCursorPtsInSeconds(6.0); auto output = ourDecoder->getNextFrame(); - torch::Tensor tensor6FromOurDecoder = output.data; + torch::Tensor tensor6FromOurDecoder = toATenTensor(output.data); EXPECT_EQ(output.ptsSeconds, 180'180. / 30'000); torch::Tensor tensor6FromFFMPEG = readTensorFromDisk("nasa_13013.mp4.time6.000000.pt"); EXPECT_TRUE(torch::equal(tensor6FromOurDecoder, tensor6FromFFMPEG)); EXPECT_EQ(ourDecoder->getDecodeStats().numSeeksAttempted, 1); - // We skipped the seek since timestamp=6 and timestamp=0 share the same - // keyframe. - EXPECT_EQ(ourDecoder->getDecodeStats().numSeeksSkipped, 1); + // lastDecodedAvFramePts_ is initialized to INT64_MIN, so the + // first seek is always performed even though timestamp=6 and timestamp=0 + // share the same keyframe. + EXPECT_EQ(ourDecoder->getDecodeStats().numSeeksSkipped, 0); // There are about 180 packets/frames between timestamp=0 and timestamp=6 at // ~30 fps. EXPECT_GT(ourDecoder->getDecodeStats().numPacketsRead, 180); @@ -316,7 +326,7 @@ TEST_P(SingleStreamDecoderTest, SeeksToFrameWithSpecificPts) { ourDecoder->setCursorPtsInSeconds(6.1); output = ourDecoder->getNextFrame(); - torch::Tensor tensor61FromOurDecoder = output.data; + auto tensor61FromOurDecoder = toATenTensor(output.data); EXPECT_EQ(output.ptsSeconds, 183'183. 
/ 30'000); torch::Tensor tensor61FromFFMPEG = readTensorFromDisk("nasa_13013.mp4.time6.100000.pt"); @@ -336,7 +346,7 @@ TEST_P(SingleStreamDecoderTest, SeeksToFrameWithSpecificPts) { ourDecoder->setCursorPtsInSeconds(10.0); output = ourDecoder->getNextFrame(); - torch::Tensor tensor10FromOurDecoder = output.data; + auto tensor10FromOurDecoder = toATenTensor(output.data); EXPECT_EQ(output.ptsSeconds, 300'300. / 30'000); torch::Tensor tensor10FromFFMPEG = readTensorFromDisk("nasa_13013.mp4.time10.000000.pt"); @@ -353,7 +363,7 @@ TEST_P(SingleStreamDecoderTest, SeeksToFrameWithSpecificPts) { ourDecoder->setCursorPtsInSeconds(6.0); output = ourDecoder->getNextFrame(); - tensor6FromOurDecoder = output.data; + tensor6FromOurDecoder = toATenTensor(output.data); EXPECT_EQ(output.ptsSeconds, 180'180. / 30'000); EXPECT_TRUE(torch::equal(tensor6FromOurDecoder, tensor6FromFFMPEG)); EXPECT_EQ(ourDecoder->getDecodeStats().numSeeksAttempted, 1); @@ -368,7 +378,7 @@ TEST_P(SingleStreamDecoderTest, SeeksToFrameWithSpecificPts) { constexpr double kPtsOfLastFrameInVideoStream = 389'389. / 30'000; // ~12.9 ourDecoder->setCursorPtsInSeconds(kPtsOfLastFrameInVideoStream); output = ourDecoder->getNextFrame(); - torch::Tensor tensor7FromOurDecoder = output.data; + auto tensor7FromOurDecoder = toATenTensor(output.data); EXPECT_EQ(output.ptsSeconds, 389'389. 
/ 30'000); torch::Tensor tensor7FromFFMPEG = readTensorFromDisk("nasa_13013.mp4.time12.979633.pt"); @@ -390,7 +400,8 @@ TEST_P(SingleStreamDecoderTest, SeeksToFrameWithSpecificPts) { TEST_P(SingleStreamDecoderTest, PreAllocatedTensorFilterGraph) { std::string path = getResourcePath("nasa_13013.mp4"); - auto preAllocatedOutputTensor = torch::empty({270, 480, 3}, {torch::kUInt8}); + auto preAllocatedOutputTensor = + toStableTensor(torch::empty({270, 480, 3}, {torch::kUInt8})); std::unique_ptr ourDecoder = SingleStreamDecoderTest::createDecoderFromPath(path, GetParam()); @@ -410,7 +421,8 @@ TEST_P(SingleStreamDecoderTest, PreAllocatedTensorFilterGraph) { TEST_P(SingleStreamDecoderTest, PreAllocatedTensorSwscale) { std::string path = getResourcePath("nasa_13013.mp4"); - auto preAllocatedOutputTensor = torch::empty({270, 480, 3}, {torch::kUInt8}); + auto preAllocatedOutputTensor = + toStableTensor(torch::empty({270, 480, 3}, {torch::kUInt8})); std::unique_ptr ourDecoder = SingleStreamDecoderTest::createDecoderFromPath(path, GetParam()); diff --git a/test/conftest.py b/test/conftest.py index bef5291e5..0588f83be 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -4,12 +4,17 @@ import pytest import torch +from .utils import in_fbcode + def pytest_configure(config): # register an additional marker (see pytest_collection_modifyitems) config.addinivalue_line( "markers", "needs_cuda: mark for tests that rely on a CUDA device" ) + config.addinivalue_line( + "markers", "needs_ffmpeg_cli: mark for tests that rely on ffmpeg" + ) def pytest_collection_modifyitems(items): @@ -28,6 +33,27 @@ def pytest_collection_modifyitems(items): # 'needs_cuda' mark, and the ones with device == 'cpu' won't have the # mark. 
needs_cuda = item.get_closest_marker("needs_cuda") is not None + needs_ffmpeg_cli = item.get_closest_marker("needs_ffmpeg_cli") is not None + has_skip_marker = item.get_closest_marker("skip") is not None + + # For skipif, the marker is always present regardless of whether the + # condition is True or False, so we must check the actual condition. + skipif_condition_is_true = any( + skipif_marker.args[0] for skipif_marker in item.iter_markers("skipif") + ) + + # If we need to conditionally skip tests based on a dependency, we should follow + # the decorator pattern used by needs_cuda and needs_ffmpeg_cli: + # 1. Define a custom marker in pytest_configure() above + # 2. Create a decorator function in utils.py (e.g., @needs_my_dependency) + # 3. Handle the marker here in pytest_collection_modifyitems() + # This keeps our skip logic centralized + + if in_fbcode(): + # fbcode doesn't like skipping tests, so instead we just don't collect the test + # so that they don't even "exist", hence the continue statements. + if needs_ffmpeg_cli or has_skip_marker or skipif_condition_is_true: + continue if ( needs_cuda diff --git a/test/generate_reference_resources.py b/test/generate_reference_resources.py index 953fb996e..7d28e9993 100644 --- a/test/generate_reference_resources.py +++ b/test/generate_reference_resources.py @@ -6,7 +6,6 @@ import subprocess from pathlib import Path -from typing import Optional import numpy as np @@ -39,7 +38,7 @@ def generate_frame_by_index( *, frame_index: int, stream_index: int, - filters: Optional[str] = None, + filters: str | None = None, ) -> None: # Note that we are using 0-based index naming. As a result, we are # generating files one-by-one, giving the actual file name that we want. @@ -51,16 +50,16 @@ def generate_frame_by_index( ) output_bmp = f"{base_path}.bmp" - # Note that we have an exlicit format conversion to rgb24 in our filtergraph specification, - # which always happens BEFORE any of the filters that we receive as input. 
We do this to - ensure that the color conversion happens BEFORE the filters, matching the behavior of the - torchcodec filtergraph implementation. - # - # Not doing this would result in the color conversion happening AFTER the filters, which - # would result in different color values for the same frame. - filtergraph = f"select='eq(n\\,{frame_index})',format=rgb24" + # Note that we have an explicit format conversion to rgb24 in our filtergraph + # specification, and we always place the user-supplied filters AFTER the + # format conversion. We do this to ensure that the filters are applied in + # RGB24 colorspace, which matches TorchCodec's behavior. + select = f"select='eq(n\\,{frame_index})'" + format = "format=rgb24" if filters is not None: - filtergraph = filtergraph + f",{filters}" + filtergraph = ",".join([select, format, filters]) + else: + filtergraph = ",".join([select, format]) cmd = [ "ffmpeg", @@ -99,7 +98,7 @@ def generate_frame_by_timestamp( convert_image_to_tensor(output_path) -def generate_nasa_13013_references(): +def generate_nasa_13013_references_by_index(): # Note: The naming scheme used here must match the naming scheme used to load # tensors in ./utils.py. streams = [0, 3] @@ -108,6 +107,8 @@ def generate_nasa_13013_references(): for frame in frames: generate_frame_by_index(NASA_VIDEO, frame_index=frame, stream_index=stream) + +def generate_nasa_13013_references_by_timestamp(): # Extract individual frames at specific timestamps, including the last frame of the video. seek_timestamp = [6.0, 6.1, 10.0, 12.979633] timestamp_name = [f"{seek_timestamp:06f}" for seek_timestamp in seek_timestamp] for timestamp, name in zip(seek_timestamp, timestamp_name): output_bmp = f"{NASA_VIDEO.path}.time{name}.bmp" generate_frame_by_timestamp(NASA_VIDEO.path, timestamp, output_bmp) + +def generate_nasa_13013_references_crop(): # Extract frames with specific filters. We have tests that assume these exact filters.
frames = [0, 15, 200, 389] crop_filter = "crop=300:200:50:35:exact=1" @@ -124,6 +127,24 @@ def generate_nasa_13013_references(): ) +def generate_nasa_13013_references_resize(): + frames = [17, 230, 389] + # Note that the resize algorithm passed to flags is exposed to users, + # but bilinear is the default we use. + resize_filter = "scale=240:135:flags=bilinear" + for frame in frames: + generate_frame_by_index( + NASA_VIDEO, frame_index=frame, stream_index=3, filters=resize_filter + ) + + +def generate_nasa_13013_references(): + generate_nasa_13013_references_by_index() + generate_nasa_13013_references_by_timestamp() + generate_nasa_13013_references_crop() + generate_nasa_13013_references_resize() + + def generate_h265_video_references(): # This video was generated by running the following: # conda install -c conda-forge x265 diff --git a/test/resources/bt2020_10bit.mp4 b/test/resources/bt2020_10bit.mp4 new file mode 100644 index 000000000..4a0bcb46f Binary files /dev/null and b/test/resources/bt2020_10bit.mp4 differ diff --git a/test/resources/bt601_full_range.mp4 b/test/resources/bt601_full_range.mp4 new file mode 100644 index 000000000..b06dcd9f8 Binary files /dev/null and b/test/resources/bt601_full_range.mp4 differ diff --git a/test/resources/bt601_limited_range.mp4 b/test/resources/bt601_limited_range.mp4 new file mode 100644 index 000000000..a6cd4fad4 Binary files /dev/null and b/test/resources/bt601_limited_range.mp4 differ diff --git a/test/resources/nasa_13013.mp4.scale_240_135_flags_bilinear.stream3.frame000017.pt b/test/resources/nasa_13013.mp4.scale_240_135_flags_bilinear.stream3.frame000017.pt new file mode 100644 index 000000000..5da3e81fe Binary files /dev/null and b/test/resources/nasa_13013.mp4.scale_240_135_flags_bilinear.stream3.frame000017.pt differ diff --git a/test/resources/nasa_13013.mp4.scale_240_135_flags_bilinear.stream3.frame000230.pt b/test/resources/nasa_13013.mp4.scale_240_135_flags_bilinear.stream3.frame000230.pt new file mode 100644 
index 000000000..5094e44da Binary files /dev/null and b/test/resources/nasa_13013.mp4.scale_240_135_flags_bilinear.stream3.frame000230.pt differ diff --git a/test/resources/nasa_13013.mp4.scale_240_135_flags_bilinear.stream3.frame000389.pt b/test/resources/nasa_13013.mp4.scale_240_135_flags_bilinear.stream3.frame000389.pt new file mode 100644 index 000000000..a15622389 Binary files /dev/null and b/test/resources/nasa_13013.mp4.scale_240_135_flags_bilinear.stream3.frame000389.pt differ diff --git a/test/resources/nasa_13013_rotated.mp4 b/test/resources/nasa_13013_rotated.mp4 new file mode 100644 index 000000000..749a1915c Binary files /dev/null and b/test/resources/nasa_13013_rotated.mp4 differ diff --git a/test/resources/nasa_13013_rotated.mp4.stream0.all_frames_info.json b/test/resources/nasa_13013_rotated.mp4.stream0.all_frames_info.json new file mode 100644 index 000000000..4c8c311e1 --- /dev/null +++ b/test/resources/nasa_13013_rotated.mp4.stream0.all_frames_info.json @@ -0,0 +1,70 @@ +[ + { + "duration_time": "0.033367", + "pts_time": "0.000000" + }, + { + "duration_time": "0.033367", + "pts_time": "0.033367" + }, + { + "duration_time": "0.033367", + "pts_time": "0.066733" + }, + { + "duration_time": "0.033367", + "pts_time": "0.100100" + }, + { + "duration_time": "0.033367", + "pts_time": "0.133467" + }, + { + "duration_time": "0.033367", + "pts_time": "0.166833" + }, + { + "duration_time": "0.033367", + "pts_time": "0.200200" + }, + { + "duration_time": "0.033367", + "pts_time": "0.233567" + }, + { + "duration_time": "0.033367", + "pts_time": "0.266933" + }, + { + "duration_time": "0.033367", + "pts_time": "0.300300" + }, + { + "duration_time": "0.033367", + "pts_time": "0.333667" + }, + { + "duration_time": "0.033367", + "pts_time": "0.367033" + }, + { + "duration_time": "0.033367", + "pts_time": "0.400400" + }, + { + "duration_time": "0.033367", + "pts_time": "0.433767" + }, + { + "duration_time": "0.033367", + "pts_time": "0.467133" + }, + { + 
"duration_time": "0.033367", + "pts_time": "0.500500" + }, + { + "duration_time": "0.033367", + "pts_time": "0.533867" + } +] diff --git a/test/resources/sine_16ch_s16.wav b/test/resources/sine_16ch_s16.wav new file mode 100644 index 000000000..556556c7c Binary files /dev/null and b/test/resources/sine_16ch_s16.wav differ diff --git a/test/resources/sine_16ch_s16.wav.stream0.all_frames_info.json b/test/resources/sine_16ch_s16.wav.stream0.all_frames_info.json new file mode 100644 index 000000000..cfc38f123 --- /dev/null +++ b/test/resources/sine_16ch_s16.wav.stream0.all_frames_info.json @@ -0,0 +1,66 @@ +[ + { + "pts_time": "0.000000", + "duration_time": "0.064000" + }, + { + "pts_time": "0.064000", + "duration_time": "0.064000" + }, + { + "pts_time": "0.128000", + "duration_time": "0.064000" + }, + { + "pts_time": "0.192000", + "duration_time": "0.064000" + }, + { + "pts_time": "0.256000", + "duration_time": "0.064000" + }, + { + "pts_time": "0.320000", + "duration_time": "0.064000" + }, + { + "pts_time": "0.384000", + "duration_time": "0.064000" + }, + { + "pts_time": "0.448000", + "duration_time": "0.064000" + }, + { + "pts_time": "0.512000", + "duration_time": "0.064000" + }, + { + "pts_time": "0.576000", + "duration_time": "0.064000" + }, + { + "pts_time": "0.640000", + "duration_time": "0.064000" + }, + { + "pts_time": "0.704000", + "duration_time": "0.064000" + }, + { + "pts_time": "0.768000", + "duration_time": "0.064000" + }, + { + "pts_time": "0.832000", + "duration_time": "0.064000" + }, + { + "pts_time": "0.896000", + "duration_time": "0.064000" + }, + { + "pts_time": "0.960000", + "duration_time": "0.040000" + } +] diff --git a/test/resources/test_non_zero_start.mp4 b/test/resources/test_non_zero_start.mp4 new file mode 100644 index 000000000..cd1815bc0 Binary files /dev/null and b/test/resources/test_non_zero_start.mp4 differ diff --git a/test/resources/test_non_zero_start.mp4.stream0.all_frames_info.json 
b/test/resources/test_non_zero_start.mp4.stream0.all_frames_info.json new file mode 100644 index 000000000..370d91af1 --- /dev/null +++ b/test/resources/test_non_zero_start.mp4.stream0.all_frames_info.json @@ -0,0 +1,50 @@ +[ + { + "pts_time": "8.333008", + "duration_time": "0.033333" + }, + { + "pts_time": "8.366341", + "duration_time": "0.033333" + }, + { + "pts_time": "8.399674", + "duration_time": "0.033333" + }, + { + "pts_time": "8.433008", + "duration_time": "0.033333" + }, + { + "pts_time": "8.466341", + "duration_time": "0.033333" + }, + { + "pts_time": "8.499674", + "duration_time": "0.033333" + }, + { + "pts_time": "8.533008", + "duration_time": "0.033333" + }, + { + "pts_time": "8.566341", + "duration_time": "0.033333" + }, + { + "pts_time": "8.599674", + "duration_time": "0.033333" + }, + { + "pts_time": "8.633008", + "duration_time": "0.033333" + }, + { + "pts_time": "8.666341", + "duration_time": "0.033333" + }, + { + "pts_time": "8.699674", + "duration_time": "0.033333" + } +] diff --git a/test/test_decoders.py b/test/test_decoders.py index 5e5028da6..489c2f936 100644 --- a/test/test_decoders.py +++ b/test/test_decoders.py @@ -6,31 +6,37 @@ import contextlib import gc -import json from functools import partial -from unittest.mock import patch import numpy import pytest import torch - -from torchcodec import _core, FrameBatch +from torchcodec import _core, ffmpeg_major_version, FrameBatch from torchcodec.decoders import ( AudioDecoder, AudioStreamMetadata, + get_nvdec_cache_capacity, set_cuda_backend, + set_nvdec_cache_capacity, VideoDecoder, VideoStreamMetadata, ) from torchcodec.decoders._decoder_utils import _get_cuda_backend +from torchcodec.decoders._wav_decoder import WavDecoder +from torchcodec.transforms import CenterCrop, RandomCrop, Resize from .utils import ( all_supported_devices, assert_frames_equal, + assert_tensor_close_on_at_least, AV1_VIDEO, + BT2020_LIMITED_RANGE_10BIT, + BT601_FULL_RANGE, + BT601_LIMITED_RANGE, BT709_FULL_RANGE, - 
cuda_version_used_for_building_torch, - get_ffmpeg_major_version, + cuda_devices, + get_ffmpeg_minor_version, + get_python_version, H264_10BITS, H265_10BITS, H265_VIDEO, @@ -40,13 +46,16 @@ NASA_AUDIO_MP3, NASA_AUDIO_MP3_44100, NASA_VIDEO, + NASA_VIDEO_ROTATED, needs_cuda, + needs_ffmpeg_cli, psnr, + SINE_16_CHANNEL_S16, SINE_MONO_S16, SINE_MONO_S32, SINE_MONO_S32_44100, SINE_MONO_S32_8000, - supports_approximate_mode, + TEST_NON_ZERO_START, TEST_SRC_2_720P, TEST_SRC_2_720P_H265, TEST_SRC_2_720P_MPEG4, @@ -118,17 +127,17 @@ def test_create_fails(self, Decoder): Decoder(123) # stream index that does not exist - with pytest.raises(ValueError, match="No valid stream found"): + with pytest.raises(ValueError, match="40 is not a valid stream"): Decoder(NASA_VIDEO.path, stream_index=40) # stream index that does exist, but it's not audio or video - with pytest.raises(ValueError, match="No valid stream found"): + with pytest.raises(ValueError, match=r"not (a|an) (video|audio) stream"): Decoder(NASA_VIDEO.path, stream_index=2) # user mistakenly forgets to specify binary reading when creating a file # like object from open() with pytest.raises(TypeError, match="binary reading?"): - Decoder(open(NASA_VIDEO.path, "r")) + Decoder(open(NASA_VIDEO.path)) class TestVideoDecoder: @@ -378,7 +387,7 @@ def test_getitem_slice(self, device, seek_mode): ] ) for sliced, ref in zip(all_frames, decoder): - if not (device == "cuda" and get_ffmpeg_major_version() == 4): + if not (device == "cuda" and ffmpeg_major_version == 4): # TODO: remove the "if". 
# See https://github.com/pytorch/torchcodec/issues/428 assert_frames_equal(sliced, ref) @@ -388,6 +397,31 @@ def test_device_instance(self): decoder = VideoDecoder(NASA_VIDEO.path, device=torch.device("cpu")) assert isinstance(decoder.metadata, VideoStreamMetadata) + @pytest.mark.parametrize( + "device_str", + [ + "cpu", + pytest.param("cuda", marks=pytest.mark.needs_cuda), + ], + ) + def test_device_none_default_device(self, device_str): + # VideoDecoder defaults to device=None, which should respect both + # torch.device() context manager and torch.set_default_device(). + + # Test with context manager + with torch.device(device_str): + decoder = VideoDecoder(NASA_VIDEO.path) + assert decoder[0].device.type == device_str + + # Test with set_default_device + original_device = torch.get_default_device() + try: + torch.set_default_device(device_str) + decoder = VideoDecoder(NASA_VIDEO.path) + assert decoder[0].device.type == device_str + finally: + torch.set_default_device(original_device) + @pytest.mark.parametrize("device", all_supported_devices()) @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) def test_getitem_fails(self, device, seek_mode): @@ -599,7 +633,7 @@ def test_get_frames_at_fails(self, device, seek_mode): @pytest.mark.parametrize("device", all_supported_devices()) def test_get_frame_at_av1(self, device): - if device == "cuda" and get_ffmpeg_major_version() == 4: + if device == "cuda" and ffmpeg_major_version == 4: return if "cuda" in device and in_fbcode(): @@ -877,56 +911,6 @@ def test_get_frames_in_range_slice_indices_syntax(self, device, seek_mode): ).to(device) assert_frames_equal(frames387_None.data, reference_frame387_389) - @pytest.mark.parametrize("device", all_supported_devices()) - @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) - @patch("torchcodec._core._metadata._get_stream_json_metadata") - def test_get_frames_with_missing_num_frames_metadata( - self, mock_get_stream_json_metadata, device, seek_mode - ): - # 
Create a mock stream_dict to test that initializing VideoDecoder without - # num_frames_from_header and num_frames_from_content calculates num_frames - # using the average_fps and duration_seconds metadata. - mock_stream_dict = { - "averageFpsFromHeader": 29.97003, - "beginStreamSecondsFromContent": 0.0, - "beginStreamSecondsFromHeader": 0.0, - "bitRate": 128783.0, - "codec": "h264", - "durationSecondsFromHeader": 13.013, - "endStreamSecondsFromContent": 13.013, - "width": 480, - "height": 270, - "mediaType": "video", - "numFramesFromHeader": None, - "numFramesFromContent": None, - } - # Set the return value of the mock to be the mock_stream_dict - mock_get_stream_json_metadata.return_value = json.dumps(mock_stream_dict) - - decoder, device = make_video_decoder( - NASA_VIDEO.path, - stream_index=3, - device=device, - seek_mode=seek_mode, - ) - - assert decoder.metadata.num_frames_from_header is None - assert decoder.metadata.num_frames_from_content is None - assert decoder.metadata.duration_seconds is not None - assert decoder.metadata.average_fps is not None - assert decoder.metadata.num_frames == int( - decoder.metadata.duration_seconds * decoder.metadata.average_fps - ) - assert len(decoder) == 390 - - # Test get_frames_in_range Python logic which uses the num_frames metadata mocked earlier. - # The frame is read at the C++ level. 
- ref_frames9 = NASA_VIDEO.get_frame_data_by_range( - start=9, stop=10, stream_index=3 - ).to(device) - frames9 = decoder.get_frames_in_range(start=9, stop=10) - assert_frames_equal(ref_frames9, frames9.data) - @pytest.mark.parametrize("dimension_order", ["NCHW", "NHWC"]) @pytest.mark.parametrize( "frame_getter", @@ -1123,6 +1107,216 @@ def test_get_frames_by_pts_in_range_fails(self, device, seek_mode): with pytest.raises(ValueError, match="Invalid stop seconds"): frame = decoder.get_frames_played_in_range(0, 23) # noqa + @pytest.mark.parametrize("device", all_supported_devices()) + @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) + def test_get_frames_played_in_range_with_fps(self, device, seek_mode): + decoder, _ = make_video_decoder( + NASA_VIDEO.path, device=device, seek_mode=seek_mode + ) + + source_fps = decoder.metadata.average_fps + duration_seconds = 1.0 + start_seconds = decoder.get_frame_at(0).pts_seconds + frame1_pts = decoder.get_frame_at(1).pts_seconds + stop_seconds = start_seconds + duration_seconds + + # Test downsampling: request lower fps than source + fps_low = 5 + frames_low_fps = decoder.get_frames_played_in_range( + start_seconds, stop_seconds, fps=fps_low + ) + expected_frames_low = round(duration_seconds * fps_low) + assert len(frames_low_fps) == expected_frames_low + # First output frame should be frame 0 + frame0_data = decoder.get_frame_at(0).data + torch.testing.assert_close(frames_low_fps.data[0], frame0_data, atol=0, rtol=0) + # Second output frame should NOT be frame 1 (we're downsampling) + frame1_data = decoder.get_frame_at(1).data + assert not torch.equal(frames_low_fps.data[1], frame1_data) + + # Test upsampling: request higher fps than source (frames should be duplicated) + # Request 3x the source fps for a single frame's duration + fps_high = int(source_fps * 3) + frames_high_fps = decoder.get_frames_played_in_range( + start_seconds, frame1_pts, fps=fps_high + ) + # All frames should be duplicates of frame 0 
since we're within frame 0's display time + frame_duration = frame1_pts - start_seconds + expected_frames_high = round(frame_duration * fps_high) + assert len(frames_high_fps) == expected_frames_high + + # All duplicated frames should have the same content as frame 0 + frame0_data = decoder.get_frame_at(0).data + if not (device == "cuda" and ffmpeg_major_version == 4): + for i in range(len(frames_high_fps)): + torch.testing.assert_close( + frames_high_fps.data[i], frame0_data, atol=0, rtol=0 + ) + + # Test that fps=None returns the original behavior (same as not passing fps) + frames_no_fps = decoder.get_frames_played_in_range(start_seconds, stop_seconds) + frames_none_fps = decoder.get_frames_played_in_range( + start_seconds, stop_seconds, fps=None + ) + assert len(frames_no_fps) == len(frames_none_fps) + if not (device == "cuda" and ffmpeg_major_version == 4): + torch.testing.assert_close( + frames_no_fps.data, frames_none_fps.data, atol=0, rtol=0 + ) + + @pytest.mark.parametrize("device", all_supported_devices()) + @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) + def test_get_frames_played_in_range_with_fps_fails(self, device, seek_mode): + decoder, _ = make_video_decoder( + NASA_VIDEO.path, device=device, seek_mode=seek_mode + ) + + start_seconds = decoder.get_frame_at(0).pts_seconds + stop_seconds = start_seconds + 1.0 + + with pytest.raises(RuntimeError, match="fps must be positive"): + decoder.get_frames_played_in_range(start_seconds, stop_seconds, fps=0) + + with pytest.raises(RuntimeError, match="fps must be positive"): + decoder.get_frames_played_in_range(start_seconds, stop_seconds, fps=-10) + + @pytest.mark.parametrize("fps", [5.0, 15.0, 24.0, 29.97, 30.1, 60.0]) + @pytest.mark.parametrize("full_video", [False, True]) + def test_get_frames_played_in_range_fps_matches_torchvision(self, fps, full_video): + """Test that TorchCodec's fps output matches torchvision's resampling logic.""" + decoder = VideoDecoder(NASA_VIDEO.path) + + if 
full_video: + start_seconds = decoder.metadata.begin_stream_seconds + stop_seconds = decoder.metadata.end_stream_seconds + else: + start_seconds = 0.0 + stop_seconds = start_seconds + 1.0 + + # Get resampled frames using our fps feature + tc_frames_batch = decoder.get_frames_played_in_range( + start_seconds=start_seconds, + stop_seconds=stop_seconds, + fps=fps, + ) + + # Get all source frames in the range + all_source_frames = decoder.get_frames_played_in_range( + start_seconds=start_seconds, + stop_seconds=stop_seconds, + ) + + # Compute expected indices using torchvision's resampling logic: + # https://github.com/pytorch/vision/blob/1e53952f57462e4c28103835cf1f9e504dbea84b/torchvision/datasets/video_utils.py#L278 + # For each output frame i, select source frame at index floor(i * step) + # where step = original_fps / target_fps + original_fps = decoder.metadata.average_fps + step = original_fps / fps + expected_indices = ( + (torch.arange(len(tc_frames_batch), dtype=torch.float32) * step) + .floor() + .to(torch.int64) + ) + expected_frames = all_source_frames.data[expected_indices] + + torch.testing.assert_close( + tc_frames_batch.data, + expected_frames, + rtol=0, + atol=0, + ) + + @pytest.mark.parametrize("device", all_supported_devices()) + @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) + def test_get_all_frames(self, device, seek_mode): + """Test that get_all_frames returns all frames and is equivalent to get_frames_played_in_range.""" + decoder, _ = make_video_decoder( + NASA_VIDEO.path, device=device, seek_mode=seek_mode + ) + + all_frames = decoder.get_all_frames() + + assert len(all_frames) == len(decoder) + + frames_in_range = decoder.get_frames_played_in_range( + start_seconds=decoder.metadata.begin_stream_seconds, + stop_seconds=decoder.metadata.end_stream_seconds, + ) + assert len(all_frames) == len(frames_in_range) + # Use strict bitwise equality, except for FFmpeg 4 + CUDA FFmpeg + # interface which has known issues (see #428) + if 
not (device == "cuda" and ffmpeg_major_version == 4): + torch.testing.assert_close( + all_frames.data, frames_in_range.data, atol=0, rtol=0 + ) + + fps = 10.0 + all_frames_with_fps = decoder.get_all_frames(fps=fps) + frames_in_range_with_fps = decoder.get_frames_played_in_range( + start_seconds=decoder.metadata.begin_stream_seconds, + stop_seconds=decoder.metadata.end_stream_seconds, + fps=fps, + ) + assert len(all_frames_with_fps) == len(frames_in_range_with_fps) + # Use strict bitwise equality, except for FFmpeg 4 + CUDA FFmpeg + # interface which has known issues (see #428) + if not (device == "cuda" and ffmpeg_major_version == 4): + torch.testing.assert_close( + all_frames_with_fps.data, frames_in_range_with_fps.data, atol=0, rtol=0 + ) + + @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) + def test_non_zero_start_pts(self, seek_mode): + """Test that frame retrieval methods return correct PTS values for videos with non-zero start time. + + This is a non-regression test for https://github.com/meta-pytorch/torchcodec/pull/1209 + """ + decoder = VideoDecoder(TEST_NON_ZERO_START.path, seek_mode=seek_mode) + + # Verify the video has a non-zero start time + assert decoder.metadata.begin_stream_seconds > 0 + expected_start_time = TEST_NON_ZERO_START.get_frame_info(0).pts_seconds + assert expected_start_time == pytest.approx(8.333, rel=1e-3) + + frame0 = decoder.get_frame_at(0) + assert frame0.pts_seconds == pytest.approx(expected_start_time, rel=1e-3) + + frame1 = decoder.get_frame_at(1) + expected_frame1_pts = TEST_NON_ZERO_START.get_frame_info(1).pts_seconds + assert frame1.pts_seconds == pytest.approx(expected_frame1_pts, rel=1e-3) + + frames = decoder.get_frames_at([0, 1, 2]) + for i, expected_idx in enumerate([0, 1, 2]): + expected_pts = TEST_NON_ZERO_START.get_frame_info(expected_idx).pts_seconds + assert frames.pts_seconds[i].item() == pytest.approx(expected_pts, rel=1e-3) + + frame_at_start = decoder.get_frame_played_at(expected_start_time) + 
assert frame_at_start.pts_seconds == pytest.approx( + expected_start_time, rel=1e-3 + ) + + frames_range = decoder.get_frames_in_range(0, 3) + for i in range(3): + expected_pts = TEST_NON_ZERO_START.get_frame_info(i).pts_seconds + assert frames_range.pts_seconds[i].item() == pytest.approx( + expected_pts, rel=1e-3 + ) + + # Use the decoder's own PTS value to avoid floating point precision issues + # between ffprobe's PTS (in JSON) and the decoder's computed PTS + frame3 = decoder.get_frame_at(3) + stop_pts = frame3.pts_seconds + frames_pts_range = decoder.get_frames_played_in_range( + expected_start_time, stop_pts + ) + # Should get frames 0, 1, 2 (stop is exclusive) + assert len(frames_pts_range) == 3 + for i in range(3): + expected_pts = TEST_NON_ZERO_START.get_frame_info(i).pts_seconds + assert frames_pts_range.pts_seconds[i].item() == pytest.approx( + expected_pts, rel=1e-3 + ) + @pytest.mark.parametrize("device", all_supported_devices()) def test_get_key_frame_indices(self, device): decoder, _ = make_video_decoder( @@ -1172,8 +1366,14 @@ def test_get_key_frame_indices(self, device): key_frame_indices, h265_reference_key_frame_indices, atol=0, rtol=0 ) + # TODO investigate why this is failing from the nightlies of Dec 09 2025. + @pytest.mark.skip(reason="TODO investigate") # TODO investigate why this fails internally. 
@pytest.mark.skipif(in_fbcode(), reason="Compile test fails internally.") + @pytest.mark.skipif( + get_python_version() >= (3, 14), + reason="torch.compile is not supported on Python 3.14+", + ) @pytest.mark.parametrize("device", all_supported_devices()) def test_compile(self, device): decoder, device = make_video_decoder(NASA_VIDEO.path, device=device) @@ -1210,7 +1410,15 @@ def get_some_frames(decoder): assert_frames_equal(ref_frame3, frames[1].data) assert_frames_equal(ref_frame5, frames[2].data) + # The test video we have is from + # https://huggingface.co/datasets/raushan-testing-hf/videos-test/blob/main/sample_video_2.avi + # We can't check it into the repo due to potential licensing issues, so + # we have to unconditionally skip this test. + # TODO: encode a video with no pts values to unskip this test. Couldn't + # find a way to do that with FFmpeg's CLI, but this should be doable + # once we have our own video encoder. @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) + @pytest.mark.skip(reason="TODO: Need video with no pts values.") def test_pts_to_dts_fallback(self, seek_mode): # Non-regression test for # https://github.com/pytorch/torchcodec/issues/677 and @@ -1218,16 +1426,6 @@ def test_pts_to_dts_fallback(self, seek_mode): # More accurately, this is a non-regression test for videos which do # *not* specify pts values (all pts values are N/A and set to # INT64_MIN), but specify *dts* value - which we fallback to. - # - # The test video we have is from - # https://huggingface.co/datasets/raushan-testing-hf/videos-test/blob/main/sample_video_2.avi - # We can't check it into the repo due to potential licensing issues, so - # we have to unconditionally skip this test.# - # TODO: encode a video with no pts values to unskip this test. Couldn't - # find a way to do that with FFmpeg's CLI, but this should be doable - # once we have our own video encoder. 
- pytest.skip(reason="TODO: Need video with no pts values.") - path = "/home/nicolashug/Downloads/sample_video_2.avi" decoder = VideoDecoder(path, seek_mode=seek_mode) metadata = decoder.metadata @@ -1270,12 +1468,51 @@ def test_full_and_studio_range_bt709_video(self, asset): gpu_frame = decoder_gpu.get_frame_at(frame_index).data.cpu() cpu_frame = decoder_cpu.get_frame_at(frame_index).data - if cuda_version_used_for_building_torch() >= (13, 0): - torch.testing.assert_close(gpu_frame, cpu_frame, rtol=0, atol=3) - elif cuda_version_used_for_building_torch() >= (12, 9): - torch.testing.assert_close(gpu_frame, cpu_frame, rtol=0, atol=2) - elif cuda_version_used_for_building_torch() == (12, 8): - assert psnr(gpu_frame, cpu_frame) > 20 + torch.testing.assert_close(gpu_frame, cpu_frame, rtol=0, atol=3) + + @needs_cuda + def test_bt2020_10bit_video(self): + # Test ensuring result consistency between CPU and beta CUDA (NVDEC) + # decoder on a BT.2020 10-bit video (limited range). This is a + # non-regression test for BT.2020 color conversion support. + # + # bt2020_10bit.mp4 is a BT.2020 limited range 10-bit HEVC video: + # color_space=bt2020nc, color_range=tv, pix_fmt=yuv420p10le + # + # NVDEC decodes 10-bit natively (converting to 8-bit NV12), then our + # BT.2020 color twist matrix handles the YUV->RGB conversion. + # + # TODO investigate CPU vs BetaCUDA mismatch on BT.2020 10-bit. + # See PR #1267 for details. 
+ asset = BT2020_LIMITED_RANGE_10BIT + + with set_cuda_backend("beta"): + decoder_gpu = VideoDecoder(asset.path, device="cuda") + decoder_cpu = VideoDecoder(asset.path, device="cpu") + + for frame_index in (0, 10, 20, 5): + gpu_frame = decoder_gpu.get_frame_at(frame_index).data.cpu() + cpu_frame = decoder_cpu.get_frame_at(frame_index).data + + assert_tensor_close_on_at_least(gpu_frame, cpu_frame, percentage=90, atol=3) + + @needs_cuda + @pytest.mark.parametrize( + "asset", + (BT601_FULL_RANGE, BT601_LIMITED_RANGE), + ) + def test_bt601_colorspace(self, asset): + # Test ensuring result consistency between CPU and beta CUDA (NVDEC) + # decoder on BT.601 videos with full and limited range. + with set_cuda_backend("beta"): + decoder_gpu = VideoDecoder(asset.path, device="cuda") + decoder_cpu = VideoDecoder(asset.path, device="cpu") + + for frame_index in (0, 10, 20, 5): + gpu_frame = decoder_gpu.get_frame_at(frame_index).data.cpu() + cpu_frame = decoder_cpu.get_frame_at(frame_index).data + + torch.testing.assert_close(gpu_frame, cpu_frame, rtol=0, atol=3) @needs_cuda def test_10bit_gpu_fallsback_to_cpu(self): @@ -1314,16 +1551,6 @@ def test_10bit_videos(self, device, asset): # This just validates that we can decode 10-bit videos. # TODO validate against the ref that the decoded frames are correct - if device == "cuda:beta" and asset is H264_10BITS: - # This fails on the BETA interface with: - # - # RuntimeError: Codec configuration not supported on this GPU. - # Codec: 4, chroma format: 1, bit depth: 10 - # - # It works on the ffmpeg interface because FFmpeg fallsback to the - # CPU, while the BETA interface doesn't. 
- pytest.skip("Asset not supported by NVDEC") - decoder, _ = make_video_decoder(asset.path, device=device) decoder.get_frame_at(10) @@ -1339,10 +1566,7 @@ def setup_frame_mappings(tmp_path, file, stream_index): # Return the custom frame mappings as a JSON string return custom_frame_mappings - @pytest.mark.skipif( - in_fbcode(), - reason="ffprobe not available internally", - ) + @needs_ffmpeg_cli @pytest.mark.parametrize("device", all_supported_devices()) @pytest.mark.parametrize("stream_index", [0, 3]) @pytest.mark.parametrize( @@ -1359,7 +1583,7 @@ def test_custom_frame_mappings_json_and_bytes( # Optionally open the custom frame mappings file if it is a file path # or use a null context if it is a string. with ( - open(custom_frame_mappings, "r") + open(custom_frame_mappings) if hasattr(custom_frame_mappings, "read") else contextlib.nullcontext() ) as custom_frame_mappings: @@ -1389,10 +1613,7 @@ def test_custom_frame_mappings_json_and_bytes( ), ) - @pytest.mark.skipif( - in_fbcode(), - reason="ffprobe not available internally", - ) + @needs_ffmpeg_cli @pytest.mark.parametrize("device", all_supported_devices()) @pytest.mark.parametrize( "custom_frame_mappings,expected_match", @@ -1434,7 +1655,7 @@ def test_custom_frame_mappings_init_fails_invalid_json(self, tmp_path, device): f.write("invalid input") # Test both file object and string - with open(invalid_json_path, "r") as file_obj: + with open(invalid_json_path) as file_obj: for custom_frame_mappings in [ file_obj, file_obj.read(), @@ -1470,8 +1691,8 @@ def test_get_frames_at_tensor_indices(self): # assert_tensor_close_on_at_least or something like that. # - unskip equality assertion checks for MPEG4 asset. The frames are decoded # fine, it's the color conversion that's different. The frame from the - # BETA interface is assumed to be 701 while the one from the default - # interface is 601. 
+ # BETA interface is mapped to 709 by the matrix coefficient using NVCUVID + # while the one from the default interface is 601. @needs_cuda @pytest.mark.parametrize( @@ -1481,7 +1702,12 @@ def test_get_frames_at_tensor_indices(self): TEST_SRC_2_720P, BT709_FULL_RANGE, TEST_SRC_2_720P_H265, - AV1_VIDEO, + pytest.param( + AV1_VIDEO, + marks=pytest.mark.skipif( + in_fbcode(), reason="AV1 CUDA not supported internally" + ), + ), TEST_SRC_2_720P_VP9, TEST_SRC_2_720P_VP8, TEST_SRC_2_720P_MPEG4, @@ -1492,12 +1718,6 @@ def test_get_frames_at_tensor_indices(self): def test_beta_cuda_interface_get_frame_at( self, asset, contiguous_indices, seek_mode ): - if seek_mode == "approximate" and not supports_approximate_mode(asset): - pytest.skip("asset doesn't work with approximate mode") - - if in_fbcode() and asset is AV1_VIDEO: - pytest.skip("AV1 CUDA not supported internally") - ref_decoder = VideoDecoder(asset.path, device="cuda", seek_mode=seek_mode) with set_cuda_backend("beta"): beta_decoder = VideoDecoder(asset.path, device="cuda", seek_mode=seek_mode) @@ -1513,7 +1733,7 @@ def test_beta_cuda_interface_get_frame_at( ref_frame = ref_decoder.get_frame_at(frame_index) beta_frame = beta_decoder.get_frame_at(frame_index) # TODONVDEC P1 see above - if get_ffmpeg_major_version() > 4 and asset is not TEST_SRC_2_720P_MPEG4: + if ffmpeg_major_version > 4 and asset is not TEST_SRC_2_720P_MPEG4: torch.testing.assert_close( beta_frame.data, ref_frame.data, rtol=0, atol=0 ) @@ -1529,7 +1749,12 @@ def test_beta_cuda_interface_get_frame_at( TEST_SRC_2_720P, BT709_FULL_RANGE, TEST_SRC_2_720P_H265, - AV1_VIDEO, + pytest.param( + AV1_VIDEO, + marks=pytest.mark.skipif( + in_fbcode(), reason="AV1 CUDA not supported internally" + ), + ), TEST_SRC_2_720P_VP9, TEST_SRC_2_720P_VP8, TEST_SRC_2_720P_MPEG4, @@ -1540,11 +1765,6 @@ def test_beta_cuda_interface_get_frame_at( def test_beta_cuda_interface_get_frames_at( self, asset, contiguous_indices, seek_mode ): - if seek_mode == "approximate" and not 
supports_approximate_mode(asset): - pytest.skip("asset doesn't work with approximate mode") - if in_fbcode() and asset is AV1_VIDEO: - pytest.skip("AV1 CUDA not supported internally") - ref_decoder = VideoDecoder(asset.path, device="cuda", seek_mode=seek_mode) with set_cuda_backend("beta"): beta_decoder = VideoDecoder(asset.path, device="cuda", seek_mode=seek_mode) @@ -1560,7 +1780,7 @@ def test_beta_cuda_interface_get_frames_at( ref_frames = ref_decoder.get_frames_at(indices) beta_frames = beta_decoder.get_frames_at(indices) # TODONVDEC P1 see above - if get_ffmpeg_major_version() > 4 and asset is not TEST_SRC_2_720P_MPEG4: + if ffmpeg_major_version > 4 and asset is not TEST_SRC_2_720P_MPEG4: torch.testing.assert_close( beta_frames.data, ref_frames.data, rtol=0, atol=0 ) @@ -1577,7 +1797,12 @@ def test_beta_cuda_interface_get_frames_at( TEST_SRC_2_720P, BT709_FULL_RANGE, TEST_SRC_2_720P_H265, - AV1_VIDEO, + pytest.param( + AV1_VIDEO, + marks=pytest.mark.skipif( + in_fbcode(), reason="AV1 CUDA not supported internally" + ), + ), TEST_SRC_2_720P_VP9, TEST_SRC_2_720P_VP8, TEST_SRC_2_720P_MPEG4, @@ -1585,11 +1810,6 @@ def test_beta_cuda_interface_get_frames_at( ) @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) def test_beta_cuda_interface_get_frame_played_at(self, asset, seek_mode): - if seek_mode == "approximate" and not supports_approximate_mode(asset): - pytest.skip("asset doesn't work with approximate mode") - if in_fbcode() and asset is AV1_VIDEO: - pytest.skip("AV1 CUDA not supported internally") - ref_decoder = VideoDecoder(asset.path, device="cuda", seek_mode=seek_mode) with set_cuda_backend("beta"): beta_decoder = VideoDecoder(asset.path, device="cuda", seek_mode=seek_mode) @@ -1603,7 +1823,7 @@ def test_beta_cuda_interface_get_frame_played_at(self, asset, seek_mode): ref_frame = ref_decoder.get_frame_played_at(pts) beta_frame = beta_decoder.get_frame_played_at(pts) # TODONVDEC P1 see above - if get_ffmpeg_major_version() > 4 and asset is not 
TEST_SRC_2_720P_MPEG4: + if ffmpeg_major_version > 4 and asset is not TEST_SRC_2_720P_MPEG4: torch.testing.assert_close( beta_frame.data, ref_frame.data, rtol=0, atol=0 ) @@ -1619,7 +1839,12 @@ def test_beta_cuda_interface_get_frame_played_at(self, asset, seek_mode): TEST_SRC_2_720P, BT709_FULL_RANGE, TEST_SRC_2_720P_H265, - AV1_VIDEO, + pytest.param( + AV1_VIDEO, + marks=pytest.mark.skipif( + in_fbcode(), reason="AV1 CUDA not supported internally" + ), + ), TEST_SRC_2_720P_VP9, TEST_SRC_2_720P_VP8, TEST_SRC_2_720P_MPEG4, @@ -1627,11 +1852,6 @@ def test_beta_cuda_interface_get_frame_played_at(self, asset, seek_mode): ) @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) def test_beta_cuda_interface_get_frames_played_at(self, asset, seek_mode): - if seek_mode == "approximate" and not supports_approximate_mode(asset): - pytest.skip("asset doesn't work with approximate mode") - if in_fbcode() and asset is AV1_VIDEO: - pytest.skip("AV1 CUDA not supported internally") - ref_decoder = VideoDecoder(asset.path, device="cuda", seek_mode=seek_mode) with set_cuda_backend("beta"): beta_decoder = VideoDecoder(asset.path, device="cuda", seek_mode=seek_mode) @@ -1645,7 +1865,7 @@ def test_beta_cuda_interface_get_frames_played_at(self, asset, seek_mode): ref_frames = ref_decoder.get_frames_played_at(timestamps) beta_frames = beta_decoder.get_frames_played_at(timestamps) # TODONVDEC P1 see above - if get_ffmpeg_major_version() > 4 and asset is not TEST_SRC_2_720P_MPEG4: + if ffmpeg_major_version > 4 and asset is not TEST_SRC_2_720P_MPEG4: torch.testing.assert_close( beta_frames.data, ref_frames.data, rtol=0, atol=0 ) @@ -1662,7 +1882,12 @@ def test_beta_cuda_interface_get_frames_played_at(self, asset, seek_mode): TEST_SRC_2_720P, BT709_FULL_RANGE, TEST_SRC_2_720P_H265, - AV1_VIDEO, + pytest.param( + AV1_VIDEO, + marks=pytest.mark.skipif( + in_fbcode(), reason="AV1 CUDA not supported internally" + ), + ), TEST_SRC_2_720P_VP9, TEST_SRC_2_720P_VP8, TEST_SRC_2_720P_MPEG4, 
@@ -1670,11 +1895,6 @@ def test_beta_cuda_interface_get_frames_played_at(self, asset, seek_mode): ) @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) def test_beta_cuda_interface_backwards(self, asset, seek_mode): - if seek_mode == "approximate" and not supports_approximate_mode(asset): - pytest.skip("asset doesn't work with approximate mode") - if in_fbcode() and asset is AV1_VIDEO: - pytest.skip("AV1 CUDA not supported internally") - ref_decoder = VideoDecoder(asset.path, device="cuda", seek_mode=seek_mode) with set_cuda_backend("beta"): beta_decoder = VideoDecoder(asset.path, device="cuda", seek_mode=seek_mode) @@ -1692,7 +1912,7 @@ def test_beta_cuda_interface_backwards(self, asset, seek_mode): ref_frame = ref_decoder.get_frame_at(frame_index) beta_frame = beta_decoder.get_frame_at(frame_index) # TODONVDEC P1 see above - if get_ffmpeg_major_version() > 4 and asset is not TEST_SRC_2_720P_MPEG4: + if ffmpeg_major_version > 4 and asset is not TEST_SRC_2_720P_MPEG4: torch.testing.assert_close( beta_frame.data, ref_frame.data, rtol=0, atol=0 ) @@ -1710,26 +1930,31 @@ def test_beta_cuda_interface_cpu_fallback(self): # to the CPU path, too. ref_dec = VideoDecoder(H265_VIDEO.path, device="cuda") - ref_frames = ref_dec.get_frame_at(0) - assert ( - _core._get_backend_details(ref_dec._decoder) - == "FFmpeg CUDA Device Interface. Using CPU fallback." - ) + + # Before accessing any frames, status should be unknown + assert not ref_dec.cpu_fallback.status_known + + ref_frame = ref_dec.get_frame_at(0) + + assert "FFmpeg CUDA" in str(ref_dec.cpu_fallback) + assert ref_dec.cpu_fallback.status_known + assert ref_dec.cpu_fallback with set_cuda_backend("beta"): beta_dec = VideoDecoder(H265_VIDEO.path, device="cuda") - assert ( - _core._get_backend_details(beta_dec._decoder) - == "Beta CUDA Device Interface. Using CPU fallback." 
- ) + assert "Beta CUDA" in str(beta_dec.cpu_fallback) + # For beta interface, status is known immediately + assert beta_dec.cpu_fallback.status_known + assert beta_dec.cpu_fallback + beta_frame = beta_dec.get_frame_at(0) - assert psnr(ref_frames.data, beta_frame.data) > 25 + assert psnr(ref_frame.data, beta_frame.data) > 25 @needs_cuda def test_beta_cuda_interface_error(self): - with pytest.raises(RuntimeError, match="Invalid device string"): + with pytest.raises(RuntimeError, match="torch_parse_device_string"): VideoDecoder(NASA_VIDEO.path, device="cuda:0:bad_variant") @needs_cuda @@ -1753,7 +1978,7 @@ def test_set_cuda_backend(self): # Check that the default is the ffmpeg backend assert _get_cuda_backend() == "ffmpeg" dec = VideoDecoder(H265_VIDEO.path, device="cuda") - assert _core._get_backend_details(dec._decoder).startswith("FFmpeg CUDA") + assert "FFmpeg CUDA" in str(dec.cpu_fallback) # Check the setting "beta" effectively uses the BETA backend. # We also show that the affects decoder creation only. When the decoder @@ -1762,22 +1987,268 @@ def test_set_cuda_backend(self): with set_cuda_backend("beta"): dec = VideoDecoder(H265_VIDEO.path, device="cuda") assert _get_cuda_backend() == "ffmpeg" - assert _core._get_backend_details(dec._decoder).startswith("Beta CUDA") + assert "Beta CUDA" in str(dec.cpu_fallback) with set_cuda_backend("ffmpeg"): - assert _core._get_backend_details(dec._decoder).startswith("Beta CUDA") + assert "Beta CUDA" in str(dec.cpu_fallback) # Hacky way to ensure passing "cuda:1" is supported by both backends. We # just check that there's an error when passing cuda:N where N is too # high. 
bad_device_number = torch.cuda.device_count() + 1 for backend in ("ffmpeg", "beta"): - with pytest.raises(RuntimeError, match="invalid device ordinal"): + with pytest.raises(RuntimeError, match="torch_call_dispatcher"): with set_cuda_backend(backend): VideoDecoder(H265_VIDEO.path, device=f"cuda:{bad_device_number}") + @contextlib.contextmanager + def restore_nvdec_cache_capacity(self): + try: + original = get_nvdec_cache_capacity() + yield + finally: + set_nvdec_cache_capacity(original) + assert get_nvdec_cache_capacity() == original + + def test_nvdec_cache_capacity(self): + with self.restore_nvdec_cache_capacity(): + set_nvdec_cache_capacity(42) + assert get_nvdec_cache_capacity() == 42 + + set_nvdec_cache_capacity(0) + assert get_nvdec_cache_capacity() == 0 + + set_nvdec_cache_capacity(1) + assert get_nvdec_cache_capacity() == 1 + + with pytest.raises( + RuntimeError, match="NVDEC cache capacity must be non-negative" + ): + set_nvdec_cache_capacity(-1) + + # Capacity is unchanged after the failed call above. 
+ assert get_nvdec_cache_capacity() == 1 + + @needs_cuda + def test_nvdec_cache_capacity_eviction(self): + + def create_decoder(): + with set_cuda_backend("beta"): + dec = VideoDecoder(NASA_VIDEO.path, device="cuda") + dec[0] + del dec + gc.collect() + + with self.restore_nvdec_cache_capacity(): + assert _core._get_nvdec_cache_size(device_index=0) == 0 + + # Create decoder, it should be in the cache + create_decoder() + assert _core._get_nvdec_cache_size(device_index=0) == 1 + + # Set capacity to 1, decoder should still be there + set_nvdec_cache_capacity(1) + assert _core._get_nvdec_cache_size(device_index=0) == 1 + # Set capacity to 0, this should evict it + set_nvdec_cache_capacity(0) + assert _core._get_nvdec_cache_size(device_index=0) == 0 + + # Create a new decoder, it's not cached since capacity is 0 + create_decoder() + assert _core._get_nvdec_cache_size(device_index=0) == 0 + + def test_cpu_fallback_no_fallback_on_cpu_device(self): + """Test that CPU device doesn't trigger fallback (it's not a fallback scenario).""" + decoder = VideoDecoder(NASA_VIDEO.path, device="cpu") + + assert decoder.cpu_fallback.status_known + _ = decoder[0] + + assert not decoder.cpu_fallback + assert "No fallback required" in str(decoder.cpu_fallback) + + @pytest.mark.parametrize("dimension_order", ["NCHW", "NHWC"]) + @pytest.mark.parametrize( + # We are skipping over cuda because we do not support rotation metadata + # for the FFmpeg CUDA interface. + "device", + ("cpu", pytest.param("cuda:beta", marks=pytest.mark.needs_cuda)), + ) + def test_rotation_applied_to_frames(self, dimension_order, device): + """Test that rotation is correctly applied to decoded frames. + + Compares frames from NASA_VIDEO_ROTATED (which has 90-degree rotation + metadata) with manually rotated frames from NASA_VIDEO. + Tests all decoding methods to ensure rotation is applied consistently. 
+ """ + decoder, _ = make_video_decoder( + NASA_VIDEO.path, + device=device, + stream_index=NASA_VIDEO.default_stream_index, + dimension_order=dimension_order, + ) + decoder_rotated, _ = make_video_decoder( + NASA_VIDEO_ROTATED.path, + device=device, + stream_index=NASA_VIDEO_ROTATED.default_stream_index, + dimension_order=dimension_order, + ) + + # Rotation dims for single frame (CHW or HWC) and batch (NCHW or NHWC) + # Rotation dims are (H, W) dimensions for each format + frame_rot_dims = (1, 2) if dimension_order == "NCHW" else (0, 1) # CHW vs HWC + batch_rot_dims = (2, 3) if dimension_order == "NCHW" else (1, 2) # NCHW vs NHWC + + # Test __getitem__ / get_frame_at (single frame by index) + for idx in [0, 5, 10]: + frame = decoder[idx] + frame_rotated = decoder_rotated[idx] + expected = torch.rot90(frame, k=1, dims=frame_rot_dims) + torch.testing.assert_close(expected, frame_rotated, atol=0, rtol=0) + + # Test get_frames_at (multiple frames by indices) + indices = [0, 5, 10] + frames = decoder.get_frames_at(indices) + frames_rotated = decoder_rotated.get_frames_at(indices) + expected = torch.rot90(frames.data, k=1, dims=batch_rot_dims) + torch.testing.assert_close(expected, frames_rotated.data, atol=0, rtol=0) + + # Test get_frames_in_range (frames by index range) + frames_range = decoder.get_frames_in_range(start=0, stop=6, step=2) + frames_range_rotated = decoder_rotated.get_frames_in_range( + start=0, stop=6, step=2 + ) + expected = torch.rot90(frames_range.data, k=1, dims=batch_rot_dims) + torch.testing.assert_close(expected, frames_range_rotated.data, atol=0, rtol=0) + + # Test get_frame_played_at (single frame by timestamp) + pts = decoder_rotated.metadata.begin_stream_seconds + frame_at_pts = decoder.get_frame_played_at(pts) + frame_at_pts_rotated = decoder_rotated.get_frame_played_at(pts) + expected = torch.rot90(frame_at_pts.data, k=1, dims=frame_rot_dims) + torch.testing.assert_close(expected, frame_at_pts_rotated.data, atol=0, rtol=0) + + # Test 
get_frames_played_at (multiple frames by timestamps) + pts_list = [ + decoder_rotated.metadata.begin_stream_seconds, + decoder_rotated.metadata.begin_stream_seconds + 0.15, + ] + frames_at_pts = decoder.get_frames_played_at(pts_list) + frames_at_pts_rotated = decoder_rotated.get_frames_played_at(pts_list) + expected = torch.rot90(frames_at_pts.data, k=1, dims=batch_rot_dims) + torch.testing.assert_close(expected, frames_at_pts_rotated.data, atol=0, rtol=0) + + # Test get_frames_played_in_range (frames by timestamp range) + start_seconds = decoder_rotated.metadata.begin_stream_seconds + stop_seconds = start_seconds + 0.2 + frames_in_range = decoder.get_frames_played_in_range( + start_seconds=start_seconds, stop_seconds=stop_seconds + ) + frames_in_range_rotated = decoder_rotated.get_frames_played_in_range( + start_seconds=start_seconds, stop_seconds=stop_seconds + ) + expected = torch.rot90(frames_in_range.data, k=1, dims=batch_rot_dims) + torch.testing.assert_close( + expected, frames_in_range_rotated.data, atol=0, rtol=0 + ) + + # Test get_all_frames (all frames in video) + # Note: NASA_VIDEO_ROTATED has fewer frames than NASA_VIDEO, so we compare + # the first N frames where N is the number of frames in the rotated video + all_frames = decoder.get_all_frames() + all_frames_rotated = decoder_rotated.get_all_frames() + num_frames_rotated = all_frames_rotated.data.shape[0] + expected = torch.rot90( + all_frames.data[:num_frames_rotated], k=1, dims=batch_rot_dims + ) + torch.testing.assert_close(expected, all_frames_rotated.data, atol=0, rtol=0) + + @pytest.mark.parametrize( + "desired_H, desired_W", + [ + (100, 150), + (150, 100), + (100, 100), + ], + ) + @pytest.mark.parametrize("TransformClass", [Resize, CenterCrop, RandomCrop]) + def test_rotation_with_transform(self, TransformClass, desired_H, desired_W): + """Test that transforms work correctly with rotated videos. 
+ + When a user specifies a transform with (H, W), they expect the final output to be + (H, W) regardless of the video's rotation metadata. This test verifies + that the transform is applied correctly such that the final output matches + the user's requested dimensions. + """ + decoder = VideoDecoder( + NASA_VIDEO_ROTATED.path, + transforms=[TransformClass((desired_H, desired_W))], + ) + frame = decoder[0] + + assert frame.shape == (3, desired_H, desired_W) + + # Also test batch APIs + frames = decoder.get_frames_at([0, 1]) + assert frames.data.shape == (2, 3, desired_H, desired_W) + + def test_rotation_with_transform_pipeline(self): + """Test that a pipeline of multiple transforms works correctly with rotated videos. + + This test verifies that chaining multiple transforms (e.g., Resize -> Resize -> Crop) + works as expected when the video has rotation metadata. Each transform should + operate on the output of the previous transform in post-rotation coordinate space. + """ + decoder = VideoDecoder( + NASA_VIDEO_ROTATED.path, + transforms=[Resize((400, 300)), Resize((300, 250)), CenterCrop((100, 100))], + ) + frame = decoder[0] + assert frame.shape == (3, 100, 100) + + frames = decoder.get_frames_at([0, 1]) + assert frames.data.shape == (2, 3, 100, 100) + + @needs_cuda + @pytest.mark.parametrize("device", cuda_devices()) + def test_cpu_fallback_h265_video(self, device): + """Test that H265 video triggers CPU fallback on CUDA interfaces.""" + # H265_VIDEO is known to trigger CPU fallback on CUDA + # because its dimensions are too small + decoder, _ = make_video_decoder(H265_VIDEO.path, device=device) + + if "beta" in device: + # For beta interface, status is known immediately + assert decoder.cpu_fallback.status_known + assert decoder.cpu_fallback + # Beta interface provides the specific reason for fallback + assert "Video not supported" in str(decoder.cpu_fallback) + else: + # For FFmpeg interface, status is unknown until first frame is decoded + assert not 
decoder.cpu_fallback.status_known + decoder.get_frame_at(0) + assert decoder.cpu_fallback.status_known + assert decoder.cpu_fallback + # FFmpeg interface doesn't know the specific reason + assert "Unknown reason - try the Beta interface to know more" in str( + decoder.cpu_fallback + ) + + @needs_cuda + @pytest.mark.parametrize("device", cuda_devices()) + def test_cpu_fallback_no_fallback_on_supported_video(self, device): + """Test that supported videos don't trigger fallback on CUDA.""" + decoder, _ = make_video_decoder(NASA_VIDEO.path, device=device) + + decoder[0] + + assert not decoder.cpu_fallback + assert "No fallback required" in str(decoder.cpu_fallback) + class TestAudioDecoder: - @pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3, SINE_MONO_S32)) + @pytest.mark.parametrize( + "asset", (NASA_AUDIO, NASA_AUDIO_MP3, SINE_MONO_S32, SINE_16_CHANNEL_S16) + ) def test_metadata(self, asset): decoder = AudioDecoder(asset.path) assert isinstance(decoder.metadata, AudioStreamMetadata) @@ -1789,7 +2260,7 @@ def test_metadata(self, asset): ) expected_duration_seconds_from_header = asset.duration_seconds - if asset == NASA_AUDIO_MP3 and get_ffmpeg_major_version() >= 8: + if asset == NASA_AUDIO_MP3 and ffmpeg_major_version >= 8: expected_duration_seconds_from_header = 13.056 assert decoder.metadata.duration_seconds_from_header == pytest.approx( @@ -1835,7 +2306,7 @@ def test_get_all_samples_with_range(self, asset, stop_seconds): assert samples.sample_rate == asset.sample_rate assert samples.pts_seconds == asset.get_frame_info(idx=0).pts_seconds - @pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3)) + @pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3, SINE_16_CHANNEL_S16)) def test_get_all_samples(self, asset): decoder = AudioDecoder(asset.path) torch.testing.assert_close( @@ -2104,8 +2575,23 @@ def test_samples_duration(self, asset, sample_rate): # that the extra tensor allocation that happens within # maybeFlushSwrBuffers() is correct. 
@pytest.mark.parametrize("sample_rate", (None, 16_000)) - # FFmpeg can handle up to AV_NUM_DATA_POINTERS=8 channels - @pytest.mark.parametrize("num_channels", (1, 2, 8, None)) + @pytest.mark.parametrize( + "num_channels", + ( + 1, + 2, + 8, + 16, + pytest.param( + 24, + marks=pytest.mark.skipif( + ffmpeg_major_version == 4 and get_ffmpeg_minor_version() < 4, + reason="24 channel layout requires FFmpeg >= 4.4", + ), + ), + None, + ), + ) def test_num_channels(self, asset, sample_rate, num_channels): decoder = AudioDecoder( asset.path, sample_rate=sample_rate, num_channels=num_channels @@ -2119,12 +2605,33 @@ def test_num_channels(self, asset, sample_rate, num_channels): @pytest.mark.parametrize("asset", (SINE_MONO_S32, NASA_AUDIO_MP3)) def test_num_channels_errors(self, asset): - with pytest.raises( - RuntimeError, match="num_channels must be > 0 and <= AV_NUM_DATA_POINTERS" - ): + with pytest.raises(RuntimeError, match="num_channels must be > 0"): AudioDecoder(asset.path, num_channels=0) - with pytest.raises( - RuntimeError, match="num_channels must be > 0 and <= AV_NUM_DATA_POINTERS" - ): - # FFmpeg can handle up to AV_NUM_DATA_POINTERS=8 channels - AudioDecoder(asset.path, num_channels=9) + for num_channels in (15, 23): + with pytest.raises(RuntimeError, match="Couldn't initialize SwrContext:"): + decoder = AudioDecoder(asset.path, num_channels=num_channels) + # Call get_all_samples to trigger num_channels conversion. + # FFmpeg fails to find a default layout for certain channel counts, + # which causes SwrContext to fail to initialize. 
+ decoder.get_all_samples() + + +class TestWavDecoder: + def test_metadata(self): + asset = SINE_MONO_S32 + wav_decoder = WavDecoder(asset.path) + audio_decoder = AudioDecoder(asset.path) + + assert isinstance(wav_decoder.metadata, AudioStreamMetadata) + assert wav_decoder.stream_index == audio_decoder.metadata.stream_index + assert wav_decoder.metadata == audio_decoder.metadata + + def test_tensor_handle_creation(self): + wav_dec = WavDecoder(SINE_MONO_S32.path) + assert wav_dec._decoder is not None + assert wav_dec.stream_index == 0 + assert wav_dec._source == SINE_MONO_S32.path + + def test_non_wav_file_raises_error(self): + with pytest.raises(RuntimeError, match="Missing RIFF header"): + WavDecoder(NASA_AUDIO.path) diff --git a/test/test_encoders.py b/test/test_encoders.py index b7223c88a..361a6eb04 100644 --- a/test/test_encoders.py +++ b/test/test_encoders.py @@ -1,6 +1,7 @@ import io import json import os +import platform import re import subprocess import sys @@ -9,24 +10,28 @@ import pytest import torch -from torchcodec.decoders import AudioDecoder +from torchcodec import ffmpeg_major_version +from torchcodec.decoders import AudioDecoder, VideoDecoder from torchcodec.encoders import AudioEncoder, VideoEncoder from .utils import ( assert_tensor_close_on_at_least, - get_ffmpeg_major_version, get_ffmpeg_minor_version, in_fbcode, + IN_GITHUB_CI, IS_WINDOWS, NASA_AUDIO_MP3, + needs_ffmpeg_cli, + psnr, SINE_MONO_S32, + TEST_SRC_2_720P, TestContainerFile, ) IS_WINDOWS_WITH_FFMPEG_LE_70 = IS_WINDOWS and ( - get_ffmpeg_major_version() < 7 - or (get_ffmpeg_major_version() == 7 and get_ffmpeg_minor_version() == 0) + ffmpeg_major_version < 7 + or (ffmpeg_major_version == 7 and get_ffmpeg_minor_version() == 0) ) @@ -44,6 +49,18 @@ def validate_frames_properties(*, actual: Path, expected: Path): # `ffprobe` on both, and assert that the frame properties match (pts, # duration, etc.) 
+ # non-exhaustive list of the props we want to test for: + required_props = ( + "pts", + "pts_time", + "sample_fmt", + "nb_samples", + "channels", + "duration", + "duration_time", + ) + show_entries = "frame=" + ",".join(required_props) + frames_actual, frames_expected = ( json.loads( subprocess.run( @@ -55,6 +72,8 @@ def validate_frames_properties(*, actual: Path, expected: Path): "-select_streams", "a:0", "-show_frames", + "-show_entries", + show_entries, "-of", "json", f"{f}", @@ -76,21 +95,10 @@ def validate_frames_properties(*, actual: Path, expected: Path): assert len(frames_actual) > 3 # arbitrary sanity check assert len(frames_actual) == len(frames_expected) - # non-exhaustive list of the props we want to test for: - required_props = ( - "pts", - "pts_time", - "sample_fmt", - "nb_samples", - "channels", - "duration", - "duration_time", - ) - for frame_index, (d_actual, d_expected) in enumerate( zip(frames_actual, frames_expected) ): - if get_ffmpeg_major_version() >= 6: + if ffmpeg_major_version >= 6: assert all(required_prop in d_expected for required_prop in required_props) for prop in d_expected: @@ -215,13 +223,22 @@ def test_bad_input_parametrized(self, method, tmp_path): getattr(decoder, method)(**valid_params, num_channels=num_channels) @pytest.mark.parametrize("method", ("to_file", "to_tensor", "to_file_like")) - @pytest.mark.parametrize("format", ("wav", "flac")) + @pytest.mark.parametrize( + "format", + [ + pytest.param( + "wav", + marks=pytest.mark.skipif( + ffmpeg_major_version == 4, + reason="Swresample with FFmpeg 4 doesn't work on wav files", + ), + ), + "flac", + ], + ) def test_round_trip(self, method, format, tmp_path): # Check that decode(encode(samples)) == samples on lossless formats - if get_ffmpeg_major_version() == 4 and format == "wav": - pytest.skip("Swresample with FFmpeg 4 doesn't work on wav files") - asset = NASA_AUDIO_MP3 source_samples = self.decode(asset).data @@ -247,12 +264,32 @@ def test_round_trip(self, method, format, 
tmp_path): self.decode(encoded_source).data, source_samples, rtol=rtol, atol=atol ) - @pytest.mark.skipif(in_fbcode(), reason="TODO: enable ffmpeg CLI") + @needs_ffmpeg_cli @pytest.mark.parametrize("asset", (NASA_AUDIO_MP3, SINE_MONO_S32)) @pytest.mark.parametrize("bit_rate", (None, 0, 44_100, 999_999_999)) @pytest.mark.parametrize("num_channels", (None, 1, 2)) @pytest.mark.parametrize("sample_rate", (8_000, 32_000)) - @pytest.mark.parametrize("format", ("mp3", "wav", "flac")) + @pytest.mark.parametrize( + "format", + [ + # TODO: https://github.com/pytorch/torchcodec/issues/837 + pytest.param( + "mp3", + marks=pytest.mark.skipif( + IS_WINDOWS and ffmpeg_major_version <= 5, + reason="Encoding mp3 on Windows is weirdly buggy", + ), + ), + pytest.param( + "wav", + marks=pytest.mark.skipif( + ffmpeg_major_version == 4, + reason="Swresample with FFmpeg 4 doesn't work on wav files", + ), + ), + "flac", + ], + ) @pytest.mark.parametrize("method", ("to_file", "to_tensor", "to_file_like")) def test_against_cli( self, @@ -269,12 +306,6 @@ def test_against_cli( # Encodes samples with our encoder and with the FFmpeg CLI, and checks # that both decoded outputs are equal - if get_ffmpeg_major_version() == 4 and format == "wav": - pytest.skip("Swresample with FFmpeg 4 doesn't work on wav files") - if IS_WINDOWS and get_ffmpeg_major_version() <= 5 and format == "mp3": - # TODO: https://github.com/pytorch/torchcodec/issues/837 - pytest.skip("Encoding mp3 on Windows is weirdly buggy") - encoded_by_ffmpeg = tmp_path / f"ffmpeg_output.{format}" subprocess.run( ["ffmpeg", "-i", str(asset.path)] @@ -316,7 +347,11 @@ def test_against_cli( assert_close = torch.testing.assert_close if sample_rate != asset.sample_rate: - rtol, atol = 0, 1e-3 + if platform.machine().lower() == "aarch64": + rtol, atol = 0, 1e-2 + else: + rtol, atol = 0, 1e-3 + if sys.platform == "darwin": assert_close = partial(assert_tensor_close_on_at_least, percentage=99) elif format == "wav": @@ -354,17 +389,31 @@ def 
test_against_cli( @pytest.mark.parametrize("asset", (NASA_AUDIO_MP3, SINE_MONO_S32)) @pytest.mark.parametrize("bit_rate", (None, 0, 44_100, 999_999_999)) @pytest.mark.parametrize("num_channels", (None, 1, 2)) - @pytest.mark.parametrize("format", ("mp3", "wav", "flac")) + @pytest.mark.parametrize( + "format", + [ + # TODO: https://github.com/pytorch/torchcodec/issues/837 + pytest.param( + "mp3", + marks=pytest.mark.skipif( + IS_WINDOWS and ffmpeg_major_version <= 5, + reason="Encoding mp3 on Windows is weirdly buggy", + ), + ), + pytest.param( + "wav", + marks=pytest.mark.skipif( + ffmpeg_major_version == 4, + reason="Swresample with FFmpeg 4 doesn't work on wav files", + ), + ), + "flac", + ], + ) @pytest.mark.parametrize("method", ("to_tensor", "to_file_like")) def test_against_to_file( self, asset, bit_rate, num_channels, format, tmp_path, method ): - if get_ffmpeg_major_version() == 4 and format == "wav": - pytest.skip("Swresample with FFmpeg 4 doesn't work on wav files") - if IS_WINDOWS and get_ffmpeg_major_version() <= 5 and format == "mp3": - # TODO: https://github.com/pytorch/torchcodec/issues/837 - pytest.skip("Encoding mp3 on Windows is weirdly buggy") - encoder = AudioEncoder(self.decode(asset).data, sample_rate=asset.sample_rate) params = dict(bit_rate=bit_rate, num_channels=num_channels) @@ -450,9 +499,6 @@ def encode_to_tensor(samples): encoded_from_contiguous, encoded_from_non_contiguous, rtol=0, atol=0 ) - @pytest.mark.skip( - reason="Flaky test, see https://github.com/pytorch/torchcodec/issues/724" - ) @pytest.mark.parametrize("num_channels_input", (1, 2)) @pytest.mark.parametrize("num_channels_output", (1, 2, None)) @pytest.mark.parametrize("method", ("to_file", "to_tensor", "to_file_like")) @@ -567,6 +613,66 @@ def write(self, data): class TestVideoEncoder: + def decode(self, source=None) -> torch.Tensor: + return VideoDecoder(source).get_frames_in_range(start=0, stop=30).data + + # TODO: add average_fps field to TestVideo asset + def 
decode_and_get_frame_rate(self, source=None): + decoder = VideoDecoder(source) + frames = decoder.get_frames_in_range(start=0, stop=30).data + frame_rate = decoder.metadata.average_fps + return frames, frame_rate + + def _get_video_metadata(self, file_path, fields): + """Helper function to get video metadata from a file using ffprobe.""" + result = subprocess.run( + [ + "ffprobe", + "-v", + "error", + "-select_streams", + "v:0", + "-show_entries", + f"stream={','.join(fields)}", + "-of", + "default=noprint_wrappers=1", + str(file_path), + ], + capture_output=True, + check=True, + text=True, + ) + metadata = {} + for line in result.stdout.strip().split("\n"): + if "=" in line: + key, value = line.split("=", 1) + metadata[key] = value + assert all(field in metadata for field in fields) + return metadata + + def _get_frames_info(self, file_path, fields): + """Helper function to get frame info (pts, dts, etc.) using ffprobe.""" + result = subprocess.run( + [ + "ffprobe", + "-v", + "error", + "-select_streams", + "v:0", + "-show_entries", + f"frame={','.join(fields)}", + "-of", + "json", + str(file_path), + ], + capture_output=True, + check=True, + text=True, + ) + frames = json.loads(result.stdout)["frames"] + assert all(field in frame for field in fields for frame in frames) + return frames + @pytest.mark.parametrize("method", ("to_file", "to_tensor", "to_file_like")) def test_bad_input_parameterized(self, tmp_path, method): if method == "to_file": @@ -605,6 +711,47 @@ def test_bad_input_parameterized(self, tmp_path, method): ) getattr(encoder, method)(**valid_params) + with pytest.raises( + RuntimeError, + match=r"Video codec invalid_codec_name not found.", + ): + encoder = VideoEncoder( + frames=torch.zeros((5, 3, 64, 64), dtype=torch.uint8), + frame_rate=30, + ) + encoder.to_file(str(tmp_path / "output.mp4"), codec="invalid_codec_name") + + with pytest.raises(RuntimeError, match=r"crf=-10 is out of valid range"): + encoder = VideoEncoder( + frames=torch.zeros((5, 
3, 64, 64), dtype=torch.uint8), + frame_rate=30, + ) + getattr(encoder, method)(**valid_params, crf=-10) + + with pytest.raises( + RuntimeError, + match=r"avcodec_open2 failed: Invalid argument", + ): + encoder.to_tensor(format="mp4", preset="fake_preset") + + @pytest.mark.parametrize("method", ["to_file", "to_tensor", "to_file_like"]) + @pytest.mark.parametrize("crf", [23, 23.5, -0.9]) + def test_crf_valid_values(self, method, crf, tmp_path): + if method == "to_file": + valid_params = {"dest": str(tmp_path / "test.mp4")} + elif method == "to_tensor": + valid_params = {"format": "mp4"} + elif method == "to_file_like": + valid_params = dict(file_like=io.BytesIO(), format="mp4") + else: + raise ValueError(f"Unknown method: {method}") + + encoder = VideoEncoder( + frames=torch.zeros((5, 3, 64, 64), dtype=torch.uint8), + frame_rate=30, + ) + getattr(encoder, method)(**valid_params, crf=crf) + def test_bad_input(self, tmp_path): encoder = VideoEncoder( frames=torch.zeros((5, 3, 64, 64), dtype=torch.uint8), @@ -630,15 +777,112 @@ def test_bad_input(self, tmp_path): encoder.to_tensor(format="bad_format") @pytest.mark.parametrize("method", ("to_file", "to_tensor", "to_file_like")) - def test_contiguity(self, method, tmp_path): + @pytest.mark.parametrize( + "device", ("cpu", pytest.param("cuda", marks=pytest.mark.needs_cuda)) + ) + def test_pixel_format_errors(self, method, device, tmp_path): + frames = torch.zeros((5, 3, 64, 64), dtype=torch.uint8).to(device) + encoder = VideoEncoder(frames, frame_rate=30) + + if method == "to_file": + valid_params = dict(dest=str(tmp_path / "output.mp4")) + elif method == "to_tensor": + valid_params = dict(format="mp4") + elif method == "to_file_like": + valid_params = dict(file_like=io.BytesIO(), format="mp4") + + if device == "cuda": + with pytest.raises( + RuntimeError, + match="Video encoding on GPU currently only supports the nv12 pixel format. 
Do not set pixel_format to use nv12 by default.", + ): + getattr(encoder, method)(**valid_params, pixel_format="yuv444p") + return + + with pytest.raises( + RuntimeError, + match=r"Unknown pixel format: invalid_pix_fmt[\s\S]*Supported pixel formats.*yuv420p", + ): + getattr(encoder, method)(**valid_params, pixel_format="invalid_pix_fmt") + + with pytest.raises( + RuntimeError, + match=r"Specified pixel format rgb24 is not supported[\s\S]*Supported pixel formats.*yuv420p", + ): + getattr(encoder, method)(**valid_params, pixel_format="rgb24") + + @pytest.mark.parametrize( + "extra_options,error", + [ + ({"qp": -10}, "qp=-10 is out of valid range"), + ( + {"qp": ""}, + "Option qp expects a numeric value but got", + ), + ( + {"direct-pred": "a"}, + "Option direct-pred expects a numeric value but got 'a'", + ), + ({"tune": "not_a_real_tune"}, "avcodec_open2 failed: Invalid argument"), + ( + {"tune": 10}, + "avcodec_open2 failed: Invalid argument", + ), + ], + ) + @pytest.mark.parametrize("method", ("to_file", "to_tensor", "to_file_like")) + def test_extra_options_errors(self, method, tmp_path, extra_options, error): + frames = torch.zeros((5, 3, 64, 64), dtype=torch.uint8) + encoder = VideoEncoder(frames, frame_rate=30) + + if method == "to_file": + valid_params = dict(dest=str(tmp_path / "output.mp4")) + elif method == "to_tensor": + valid_params = dict(format="mp4") + elif method == "to_file_like": + valid_params = dict(file_like=io.BytesIO(), format="mp4") + else: + raise ValueError(f"Unknown method: {method}") + + with pytest.raises( + RuntimeError, + match=error, + ): + getattr(encoder, method)(**valid_params, extra_options=extra_options) + + @pytest.mark.parametrize("method", ("to_file", "to_tensor", "to_file_like")) + @pytest.mark.parametrize( + "device", + ( + "cpu", + pytest.param( + "cuda", + marks=[ + pytest.mark.needs_cuda, + pytest.mark.skipif( + in_fbcode(), reason="NVENC not available in fbcode" + ), + pytest.mark.skipif( + ffmpeg_major_version == 4, + 
reason="CUDA + FFmpeg 4 test is flaky", + ), + ], + ), + ), + ) + def test_contiguity(self, method, tmp_path, device): # Ensure that 2 sets of video frames with the same pixel values are encoded # in the same way, regardless of their memory layout. Here we encode 2 equal # frame tensors, one is contiguous while the other is non-contiguous. - num_frames, channels, height, width = 5, 3, 64, 64 - contiguous_frames = torch.randint( - 0, 256, size=(num_frames, channels, height, width), dtype=torch.uint8 - ).contiguous() + num_frames, channels, height, width = 5, 3, 256, 256 + contiguous_frames = ( + torch.randint( + 0, 256, size=(num_frames, channels, height, width), dtype=torch.uint8 + ) + .contiguous() + .to(device) + ) assert contiguous_frames.is_contiguous() # Permute NCHW to NHWC, then update the memory layout, then permute back @@ -654,17 +898,23 @@ def test_contiguity(self, method, tmp_path): ) def encode_to_tensor(frames): + common_params = dict( + crf=0, + pixel_format="yuv444p" if device == "cpu" else None, + ) if method == "to_file": dest = str(tmp_path / "output.mp4") - VideoEncoder(frames, frame_rate=30).to_file(dest=dest) + VideoEncoder(frames, frame_rate=30).to_file(dest=dest, **common_params) with open(dest, "rb") as f: return torch.frombuffer(f.read(), dtype=torch.uint8) elif method == "to_tensor": - return VideoEncoder(frames, frame_rate=30).to_tensor(format="mp4") + return VideoEncoder(frames, frame_rate=30).to_tensor( + format="mp4", **common_params + ) elif method == "to_file_like": file_like = io.BytesIO() VideoEncoder(frames, frame_rate=30).to_file_like( - file_like, format="mp4" + file_like, format="mp4", **common_params ) return torch.frombuffer(file_like.getvalue(), dtype=torch.uint8) else: @@ -676,3 +926,637 @@ def encode_to_tensor(frames): torch.testing.assert_close( encoded_from_contiguous, encoded_from_non_contiguous, rtol=0, atol=0 ) + + @pytest.mark.parametrize( + "format", + [ + "mov", + "mp4", + "mkv", + pytest.param( + "webm", + 
marks=[ + pytest.mark.slow, + pytest.mark.skipif( + ffmpeg_major_version == 4 + or (IS_WINDOWS and ffmpeg_major_version >= 6), + reason="Codec for webm is not available in this FFmpeg installation.", + ), + ], + ), + ], + ) + @pytest.mark.parametrize("method", ("to_file", "to_tensor", "to_file_like")) + def test_round_trip(self, tmp_path, format, method): + # Test that decode(encode(decode(frames))) == decode(frames) + source_frames, frame_rate = self.decode_and_get_frame_rate(TEST_SRC_2_720P.path) + + encoder = VideoEncoder(frames=source_frames, frame_rate=frame_rate) + + if method == "to_file": + encoded_path = str(tmp_path / f"encoder_output.{format}") + encoder.to_file(dest=encoded_path, pixel_format="yuv444p", crf=0) + round_trip_frames = self.decode(encoded_path) + elif method == "to_tensor": + encoded_tensor = encoder.to_tensor( + format=format, pixel_format="yuv444p", crf=0 + ) + round_trip_frames = self.decode(encoded_tensor) + elif method == "to_file_like": + file_like = io.BytesIO() + encoder.to_file_like( + file_like=file_like, format=format, pixel_format="yuv444p", crf=0 + ) + round_trip_frames = self.decode(file_like.getvalue()) + else: + raise ValueError(f"Unknown method: {method}") + + assert source_frames.shape == round_trip_frames.shape + assert source_frames.dtype == round_trip_frames.dtype + + atol = 3 if format == "webm" else 2 + for s_frame, rt_frame in zip(source_frames, round_trip_frames): + assert psnr(s_frame, rt_frame) > 30 + torch.testing.assert_close(s_frame, rt_frame, atol=atol, rtol=0) + + @pytest.mark.parametrize( + "format", + [ + "mov", + "mp4", + "avi", + "mkv", + "flv", + "gif", + pytest.param( + "webm", + marks=[ + pytest.mark.slow, + pytest.mark.skipif( + ffmpeg_major_version == 4 + or (IS_WINDOWS and ffmpeg_major_version >= 6), + reason="Codec for webm is not available in this FFmpeg installation.", + ), + ], + ), + ], + ) + @pytest.mark.parametrize("method", ("to_tensor", "to_file_like")) + def test_against_to_file(self, 
tmp_path, format, method): + # Test that to_file, to_tensor, and to_file_like produce the same results + source_frames, frame_rate = self.decode_and_get_frame_rate(TEST_SRC_2_720P.path) + encoder = VideoEncoder(frames=source_frames, frame_rate=frame_rate) + + encoded_file = tmp_path / f"output.{format}" + encoder.to_file(dest=encoded_file, crf=0) + + if method == "to_tensor": + encoded_output = encoder.to_tensor(format=format, crf=0) + else: # to_file_like + file_like = io.BytesIO() + encoder.to_file_like(file_like=file_like, format=format, crf=0) + encoded_output = file_like.getvalue() + + torch.testing.assert_close( + self.decode(encoded_file), + self.decode(encoded_output), + atol=0, + rtol=0, + ) + + @needs_ffmpeg_cli + @pytest.mark.parametrize( + "format", + ( + "mov", + "mp4", + "avi", + "mkv", + "flv", + pytest.param( + "webm", + marks=[ + pytest.mark.slow, + pytest.mark.skipif( + ffmpeg_major_version == 4 + or (IS_WINDOWS and ffmpeg_major_version >= 6), + reason="Codec for webm is not available in this FFmpeg installation.", + ), + ], + ), + ), + ) + @pytest.mark.parametrize( + "encode_params", + [ + {"pixel_format": "yuv444p", "crf": 0, "preset": None}, + {"pixel_format": "yuv420p", "crf": 30, "preset": None}, + {"pixel_format": "yuv420p", "crf": None, "preset": "ultrafast"}, + {"pixel_format": "yuv420p", "crf": None, "preset": None}, + ], + ) + @pytest.mark.parametrize("method", ("to_file", "to_tensor", "to_file_like")) + @pytest.mark.parametrize("frame_rate", [30, 29.97]) + def test_video_encoder_against_ffmpeg_cli( + self, tmp_path, format, encode_params, method, frame_rate + ): + pixel_format = encode_params["pixel_format"] + crf = encode_params["crf"] + preset = encode_params["preset"] + + if format in ("avi", "flv") and pixel_format == "yuv444p": + pytest.skip(f"Default codec for {format} does not support {pixel_format}") + + source_frames = self.decode(TEST_SRC_2_720P.path) + + # Encode with FFmpeg CLI + temp_raw_path = str(tmp_path / 
"temp_input.raw") + with open(temp_raw_path, "wb") as f: + f.write(source_frames.permute(0, 2, 3, 1).cpu().numpy().tobytes()) + + ffmpeg_encoded_path = str(tmp_path / f"ffmpeg_output.{format}") + # Some codecs (ex. MPEG4) do not support CRF or preset. + # Flags not supported by the selected codec will be ignored. + ffmpeg_cmd = [ + "ffmpeg", + "-y", + "-f", + "rawvideo", + "-pix_fmt", + "rgb24", # Input format + "-s", + f"{source_frames.shape[3]}x{source_frames.shape[2]}", + "-r", + str(frame_rate), + "-i", + temp_raw_path, + ] + if pixel_format is not None: # Output format + ffmpeg_cmd.extend(["-pix_fmt", pixel_format]) + if preset is not None: + ffmpeg_cmd.extend(["-preset", preset]) + if crf is not None: + ffmpeg_cmd.extend(["-crf", str(crf)]) + # Output path must be last + ffmpeg_cmd.append(ffmpeg_encoded_path) + subprocess.run(ffmpeg_cmd, check=True) + ffmpeg_frames = self.decode(ffmpeg_encoded_path).data + + # Encode with our video encoder + encoder = VideoEncoder(frames=source_frames, frame_rate=frame_rate) + encoder_output_path = str(tmp_path / f"encoder_output.{format}") + + if method == "to_file": + encoder.to_file( + dest=encoder_output_path, + pixel_format=pixel_format, + crf=crf, + preset=preset, + ) + encoder_frames = self.decode(encoder_output_path) + elif method == "to_tensor": + encoded_output = encoder.to_tensor( + format=format, + pixel_format=pixel_format, + crf=crf, + preset=preset, + ) + encoder_frames = self.decode(encoded_output) + elif method == "to_file_like": + file_like = io.BytesIO() + encoder.to_file_like( + file_like=file_like, + format=format, + pixel_format=pixel_format, + crf=crf, + preset=preset, + ) + encoder_frames = self.decode(file_like.getvalue()) + else: + raise ValueError(f"Unknown method: {method}") + + assert ffmpeg_frames.shape[0] == encoder_frames.shape[0] + + # MPEG codec used for avi format does not accept CRF + percentage = 94 if format == "avi" else 99 + + # Check that PSNR between both encoded versions is high + 
for ff_frame, enc_frame in zip(ffmpeg_frames, encoder_frames): + res = psnr(ff_frame, enc_frame) + assert res > 30 + assert_tensor_close_on_at_least( + ff_frame, enc_frame, percentage=percentage, atol=2 + ) + + # Only compare video metadata on ffmpeg versions >= 6, as older versions + # are often missing metadata + if ffmpeg_major_version >= 6 and method == "to_file": + fields = [ + "duration", + "duration_ts", + "r_frame_rate", + "time_base", + "nb_frames", + ] + ffmpeg_metadata = self._get_video_metadata( + ffmpeg_encoded_path, + fields=fields, + ) + encoder_metadata = self._get_video_metadata( + encoder_output_path, + fields=fields, + ) + assert ffmpeg_metadata == encoder_metadata + + # Check that frame timestamps and duration are the same + fields = ("pts", "pts_time") + if format != "flv": + fields += ("duration", "duration_time") + ffmpeg_frames_info = self._get_frames_info( + ffmpeg_encoded_path, fields=fields + ) + encoder_frames_info = self._get_frames_info( + encoder_output_path, fields=fields + ) + assert ffmpeg_frames_info == encoder_frames_info + + def test_to_file_like_custom_file_object(self): + """Test to_file_like with a custom file-like object that implements write and seek.""" + + class CustomFileObject: + def __init__(self): + self._file = io.BytesIO() + + def write(self, data): + return self._file.write(data) + + def seek(self, offset, whence=0): + return self._file.seek(offset, whence) + + def get_encoded_data(self): + return self._file.getvalue() + + source_frames, frame_rate = self.decode_and_get_frame_rate(TEST_SRC_2_720P.path) + encoder = VideoEncoder(frames=source_frames, frame_rate=frame_rate) + + file_like = CustomFileObject() + encoder.to_file_like(file_like, format="mp4", pixel_format="yuv444p", crf=0) + decoded_frames = self.decode(file_like.get_encoded_data()) + + torch.testing.assert_close( + decoded_frames, + source_frames, + atol=2, + rtol=0, + ) + + def test_to_file_like_real_file(self, tmp_path): + """Test to_file_like with a 
real file opened in binary write mode.""" + source_frames, frame_rate = self.decode_and_get_frame_rate(TEST_SRC_2_720P.path) + encoder = VideoEncoder(frames=source_frames, frame_rate=frame_rate) + + file_path = tmp_path / "test_file_like.mp4" + + with open(file_path, "wb") as file_like: + encoder.to_file_like(file_like, format="mp4", pixel_format="yuv444p", crf=0) + decoded_frames = self.decode(str(file_path)) + + torch.testing.assert_close( + decoded_frames, + source_frames, + atol=2, + rtol=0, + ) + + def test_to_file_like_bad_methods(self): + source_frames, frame_rate = self.decode_and_get_frame_rate(TEST_SRC_2_720P.path) + encoder = VideoEncoder(frames=source_frames, frame_rate=frame_rate) + + class NoWriteMethod: + def seek(self, offset, whence=0): + return 0 + + with pytest.raises( + RuntimeError, match="File like object must implement a write method" + ): + encoder.to_file_like(NoWriteMethod(), format="mp4") + + class NoSeekMethod: + def write(self, data): + return len(data) + + with pytest.raises( + RuntimeError, match="File like object must implement a seek method" + ): + encoder.to_file_like(NoSeekMethod(), format="mp4") + + @needs_ffmpeg_cli + @pytest.mark.parametrize( + "format,codec_spec", + [ + ("mp4", "h264"), + ("mp4", "hevc"), + ("mkv", "av1"), + ("avi", "mpeg4"), + pytest.param( + "webm", + "vp9", + marks=pytest.mark.skipif( + IS_WINDOWS, reason="vp9 codec not available on Windows" + ), + ), + ], + ) + def test_codec_parameter_utilized(self, tmp_path, format, codec_spec): + # Test the codec parameter is utilized by using ffprobe to check the encoded file's codec spec + frames = torch.zeros((10, 3, 64, 64), dtype=torch.uint8) + dest = str(tmp_path / f"output.{format}") + + VideoEncoder(frames=frames, frame_rate=30).to_file(dest=dest, codec=codec_spec) + actual_codec_spec = self._get_video_metadata(dest, fields=["codec_name"])[ + "codec_name" + ] + assert actual_codec_spec == codec_spec + + @needs_ffmpeg_cli + @pytest.mark.parametrize( + 
"codec_spec,codec_impl", + [ + ("h264", "libx264"), + ("av1", "libaom-av1"), + pytest.param( + "vp9", + "libvpx-vp9", + marks=pytest.mark.skipif( + IS_WINDOWS, reason="vp9 codec not available on Windows" + ), + ), + ], + ) + def test_codec_spec_vs_impl_equivalence(self, tmp_path, codec_spec, codec_impl): + # Test that using codec spec gives the same result as using default codec implementation + # We cannot directly check codec impl used, so we assert frame equality + frames = torch.randint(0, 256, (10, 3, 64, 64), dtype=torch.uint8) + + spec_output = str(tmp_path / "spec_output.mp4") + VideoEncoder(frames=frames, frame_rate=30).to_file( + dest=spec_output, codec=codec_spec + ) + + impl_output = str(tmp_path / "impl_output.mp4") + VideoEncoder(frames=frames, frame_rate=30).to_file( + dest=impl_output, codec=codec_impl + ) + + assert ( + self._get_video_metadata(spec_output, fields=["codec_name"])["codec_name"] + == codec_spec + ) + assert ( + self._get_video_metadata(impl_output, fields=["codec_name"])["codec_name"] + == codec_spec + ) + + frames_spec = self.decode(spec_output) + frames_impl = self.decode(impl_output) + torch.testing.assert_close(frames_spec, frames_impl, rtol=0, atol=0) + + @needs_ffmpeg_cli + @pytest.mark.parametrize( + "profile,colorspace,color_range", + [ + ("baseline", "bt709", "tv"), + ("main", "bt470bg", "pc"), + ("high", "fcc", "pc"), + ], + ) + def test_extra_options_utilized(self, tmp_path, profile, colorspace, color_range): + # Test setting profile, colorspace, and color_range via extra_options is utilized + source_frames = torch.zeros((5, 3, 64, 64), dtype=torch.uint8) + encoder = VideoEncoder(frames=source_frames, frame_rate=30) + + output_path = str(tmp_path / "output.mp4") + encoder.to_file( + dest=output_path, + extra_options={ + "profile": profile, + "colorspace": colorspace, + "color_range": color_range, + }, + ) + metadata = self._get_video_metadata( + output_path, + fields=["profile", "color_space", "color_range"], + ) + # 
Validate profile (case-insensitive, baseline is reported as "Constrained Baseline") + expected_profile = "constrained baseline" if profile == "baseline" else profile + assert metadata["profile"].lower() == expected_profile + assert metadata["color_space"] == colorspace + assert metadata["color_range"] == color_range + + @needs_ffmpeg_cli + @pytest.mark.needs_cuda + @pytest.mark.parametrize("method", ("to_file", "to_tensor", "to_file_like")) + @pytest.mark.parametrize( + ("format", "codec"), + [ + ("mov", None), # will default to h264_nvenc + ("mov", "h264_nvenc"), + ("avi", "h264_nvenc"), + ("mp4", "hevc_nvenc"), # use non-default codec + pytest.param( + "mkv", + "av1_nvenc", + marks=[ + pytest.mark.skipif( + IN_GITHUB_CI, reason="av1_nvenc is not supported on CI" + ), + pytest.mark.skipif( + ffmpeg_major_version == 4, + reason="av1_nvenc is not supported on FFmpeg 4", + ), + ], + ), + ], + ) + # We test the color space and color range parameters in this test, because + # we are required to define matrices specific to these specs when using NPP, see note: + # [RGB -> YUV Color Conversion, limited color range] + # BT.601, BT.709, BT.2020 + @pytest.mark.parametrize("color_space", ("bt470bg", "bt709", "bt2020nc", None)) + # Full/PC range, Limited/TV range + @pytest.mark.parametrize("color_range", ("pc", "tv", None)) + def test_nvenc_against_ffmpeg_cli( + self, tmp_path, method, format, codec, color_space, color_range + ): + # TODO-VideoEncoder: (P2) Investigate why FFmpeg 4 and 6 fail with non-default color space and range. 
+ # See https://github.com/meta-pytorch/torchcodec/issues/1140 + if ffmpeg_major_version in (4, 6) and not ( + color_space == "bt470bg" and color_range == "tv" + ): + pytest.skip( + "Non-default color space and range have lower accuracy on FFmpeg 4 and 6" + ) + # Encode with FFmpeg CLI using nvenc codecs + device = "cuda" + qp = 1 # Use near lossless encoding to reduce noise and support av1_nvenc + source_frames = self.decode(TEST_SRC_2_720P.path).data.to(device) + + temp_raw_path = str(tmp_path / "temp_input.raw") + with open(temp_raw_path, "wb") as f: + f.write(source_frames.permute(0, 2, 3, 1).cpu().numpy().tobytes()) + + ffmpeg_encoded_path = str(tmp_path / f"ffmpeg_nvenc_output.{format}") + frame_rate = 30 + + ffmpeg_cmd = [ + "ffmpeg", + "-y", + "-f", + "rawvideo", + "-pix_fmt", + "rgb24", # Input format + "-s", + f"{source_frames.shape[3]}x{source_frames.shape[2]}", + "-r", + str(frame_rate), + "-i", + temp_raw_path, + ] + # CLI requires explicit codec for nvenc + # VideoEncoder will default to h264_nvenc since the frames are on GPU. 
+ ffmpeg_cmd.extend(["-c:v", codec if codec is not None else "h264_nvenc"]) + ffmpeg_cmd.extend(["-pix_fmt", "nv12"]) # Output format is always NV12 + ffmpeg_cmd.extend(["-qp", str(qp)]) + if color_space: + ffmpeg_cmd.extend(["-colorspace", color_space]) + if color_range: + ffmpeg_cmd.extend(["-color_range", color_range]) + ffmpeg_cmd.extend([ffmpeg_encoded_path]) + subprocess.run(ffmpeg_cmd, check=True, capture_output=True) + + encoder = VideoEncoder(frames=source_frames, frame_rate=frame_rate) + encoder_extra_options = {"qp": qp} + if color_space: + encoder_extra_options["colorspace"] = color_space + if color_range: + encoder_extra_options["color_range"] = color_range + if method == "to_file": + encoder_output_path = str(tmp_path / f"nvenc_output.{format}") + encoder.to_file( + dest=encoder_output_path, + codec=codec, + extra_options=encoder_extra_options, + ) + encoder_output = encoder_output_path + elif method == "to_tensor": + encoder_output = encoder.to_tensor( + format=format, + codec=codec, + extra_options=encoder_extra_options, + ) + elif method == "to_file_like": + file_like = io.BytesIO() + encoder.to_file_like( + file_like=file_like, + format=format, + codec=codec, + extra_options=encoder_extra_options, + ) + encoder_output = file_like.getvalue() + else: + raise ValueError(f"Unknown method: {method}") + + ffmpeg_frames = self.decode(ffmpeg_encoded_path).data + encoder_frames = self.decode(encoder_output).data + + assert ffmpeg_frames.shape[0] == encoder_frames.shape[0] + for ff_frame, enc_frame in zip(ffmpeg_frames, encoder_frames): + assert psnr(ff_frame, enc_frame) > 25 + assert_tensor_close_on_at_least(ff_frame, enc_frame, percentage=96, atol=2) + + if method == "to_file": + metadata_fields = ["pix_fmt", "color_range", "color_space"] + ffmpeg_metadata = self._get_video_metadata( + ffmpeg_encoded_path, metadata_fields + ) + encoder_metadata = self._get_video_metadata(encoder_output, metadata_fields) + # pix_fmt nv12 is stored as yuv420p in metadata, 
unless full range (pc)is used + # In that case, h264 and hevc NVENC codecs will use yuvj420p automatically. + if color_range == "pc" and codec != "av1_nvenc": + expected_pix_fmt = "yuvj420p" + else: + # av1_nvenc does not utilize the yuvj420p pixel format + expected_pix_fmt = "yuv420p" + assert ( + encoder_metadata["pix_fmt"] + == ffmpeg_metadata["pix_fmt"] + == expected_pix_fmt + ) + + assert encoder_metadata["color_range"] == ffmpeg_metadata["color_range"] + assert encoder_metadata["color_space"] == ffmpeg_metadata["color_space"] + # Default values vary by codec, so we only assert when + # color_range and color_space are not None. + if color_range is not None: + # FFmpeg and torchcodec encode color_range as 'unknown' for mov and avi + # when color_range='tv' and color_space=None on FFmpeg 7/8. + # Since this failure is rare, I suspect its a bug related to these + # older container formats on newer FFmpeg versions. + if not ( + ffmpeg_major_version in (7, 8) + and color_range == "tv" + and color_space is None + and format in ("mov", "avi") + ): + assert color_range == encoder_metadata["color_range"] + if color_space is not None: + assert color_space == encoder_metadata["color_space"] + + @pytest.mark.skipif( + ffmpeg_major_version == 4, + reason="On FFmpeg 4 hitting a truncated packet results in AVERROR_INVALIDDATA, which torchcodec does not handle.", + ) + @pytest.mark.parametrize("format", ["mp4", "mov"]) + @pytest.mark.parametrize( + "extra_options", + [ + # frag_keyframe with empty_moov (new fragment every keyframe) + {"movflags": "+frag_keyframe+empty_moov"}, + # frag_duration creates fragments based on duration (in microseconds) + {"movflags": "+empty_moov", "frag_duration": "1000000"}, + ], + ) + def test_fragmented_mp4( + self, + tmp_path, + extra_options, + format, + ): + # Test that VideoEncoder can write fragmented files using movflags. 
+ # Fragmented files store metadata interleaved with data rather than + # all at the end, making them decodable even if writing is interrupted. + source_frames, frame_rate = self.decode_and_get_frame_rate(TEST_SRC_2_720P.path) + encoder = VideoEncoder(frames=source_frames, frame_rate=frame_rate) + encoded_path = str(tmp_path / f"fragmented_output.{format}") + encoder.to_file(dest=encoded_path, extra_options=extra_options) + + # Decode the file to get reference frames + reference_decoder = VideoDecoder(encoded_path) + reference_frames = [reference_decoder.get_frame_at(i) for i in range(10)] + + # Truncate the file to simulate interrupted write + with open(encoded_path, "rb") as f: + full_content = f.read() + truncated_size = int(len(full_content) * 0.5) + with open(encoded_path, "wb") as f: + f.write(full_content[:truncated_size]) + + # Decode the truncated file and verify first 10 frames match reference + truncated_decoder = VideoDecoder(encoded_path) + assert len(truncated_decoder) >= 10 + for i in range(10): + truncated_frame = truncated_decoder.get_frame_at(i) + torch.testing.assert_close( + truncated_frame.data, reference_frames[i].data, atol=0, rtol=0 + ) diff --git a/test/test_metadata.py b/test/test_metadata.py index 628b7a68d..0a0227737 100644 --- a/test/test_metadata.py +++ b/test/test_metadata.py @@ -8,18 +8,22 @@ from fractions import Fraction import pytest - +from torchcodec import ffmpeg_major_version from torchcodec._core import ( - add_video_stream, AudioStreamMetadata, - create_from_file, get_container_metadata, get_container_metadata_from_header, VideoStreamMetadata, ) +from torchcodec._core.ops import add_video_stream, create_from_file from torchcodec.decoders import AudioDecoder, VideoDecoder -from .utils import get_ffmpeg_major_version, NASA_AUDIO_MP3, NASA_VIDEO +from .utils import ( + BT2020_LIMITED_RANGE_10BIT, + NASA_AUDIO_MP3, + NASA_VIDEO, + NASA_VIDEO_ROTATED, +) # TODO: Expected values in these tests should be based on the assets's @@ 
-48,7 +52,15 @@ def _get_container_metadata(path, seek_mode): get_container_metadata_from_header, functools.partial(_get_container_metadata, seek_mode="approximate"), functools.partial(_get_container_metadata, seek_mode="exact"), - functools.partial(_get_container_metadata, seek_mode="custom_frame_mappings"), + pytest.param( + functools.partial( + _get_container_metadata, seek_mode="custom_frame_mappings" + ), + marks=pytest.mark.skipif( + ffmpeg_major_version in (4, 5), + reason="ffprobe isn't accurate on ffmpeg 4 and 5", + ), + ), ), ) def test_get_metadata(metadata_getter): @@ -57,9 +69,6 @@ def test_get_metadata(metadata_getter): if isinstance(metadata_getter, functools.partial) else None ) - if (seek_mode == "custom_frame_mappings") and get_ffmpeg_major_version() in (4, 5): - pytest.skip(reason="ffprobe isn't accurate on ffmpeg 4 and 5") - with_added_video_stream = seek_mode == "custom_frame_mappings" metadata = metadata_getter(NASA_VIDEO.path) with_scan = ( @@ -77,7 +86,6 @@ def test_get_metadata(metadata_getter): with pytest.raises(NotImplementedError, match="Decide on logic"): metadata.bit_rate - ffmpeg_major_version = get_ffmpeg_major_version() if ffmpeg_major_version <= 5: expected_duration_seconds_from_header = 16.57 expected_bit_rate_from_header = 324915 @@ -99,9 +107,8 @@ def test_get_metadata(metadata_getter): assert best_video_stream_metadata.begin_stream_seconds_from_header == 0 assert best_video_stream_metadata.bit_rate == 128783 assert best_video_stream_metadata.average_fps == pytest.approx(29.97, abs=0.001) - assert best_video_stream_metadata.pixel_aspect_ratio == ( - Fraction(1, 1) if with_added_video_stream else None - ) + assert best_video_stream_metadata.pixel_aspect_ratio == Fraction(1, 1) + assert best_video_stream_metadata.pixel_format == "yuv420p" assert best_video_stream_metadata.codec == "h264" assert best_video_stream_metadata.num_frames_from_content == ( 390 if with_scan else None @@ -132,7 +139,6 @@ def 
test_get_metadata_audio_file(metadata_getter): assert isinstance(best_audio_stream_metadata, AudioStreamMetadata) assert best_audio_stream_metadata is metadata.best_audio_stream - ffmpeg_major_version = get_ffmpeg_major_version() expected_duration_seconds_from_header = ( 13.056 if ffmpeg_major_version >= 8 else 13.248 ) @@ -147,121 +153,39 @@ def test_get_metadata_audio_file(metadata_getter): assert best_audio_stream_metadata.sample_format == "fltp" -@pytest.mark.parametrize( - "num_frames_from_header, num_frames_from_content, expected_num_frames", - [(10, 20, 20), (None, 10, 10), (10, None, 10)], -) -def test_num_frames_fallback( - num_frames_from_header, num_frames_from_content, expected_num_frames -): - """Check that num_frames_from_content always has priority when accessing `.num_frames`""" - metadata = VideoStreamMetadata( - duration_seconds_from_header=4, - bit_rate=123, - num_frames_from_header=num_frames_from_header, - num_frames_from_content=num_frames_from_content, - begin_stream_seconds_from_header=0, - begin_stream_seconds_from_content=0, - end_stream_seconds_from_content=4, - codec="whatever", - width=123, - height=321, - average_fps_from_header=30, - pixel_aspect_ratio=Fraction(1, 1), - stream_index=0, - ) - - assert metadata.num_frames == expected_num_frames - - -@pytest.mark.parametrize( - "average_fps_from_header, duration_seconds_from_header, expected_num_frames", - [(60, 10, 600), (60, None, None), (None, 10, None), (None, None, None)], -) -def test_calculate_num_frames_using_fps_and_duration( - average_fps_from_header, duration_seconds_from_header, expected_num_frames -): - """Check that if num_frames_from_content and num_frames_from_header are missing, - `.num_frames` is calculated using average_fps_from_header and duration_seconds_from_header - """ - metadata = VideoStreamMetadata( - duration_seconds_from_header=duration_seconds_from_header, - bit_rate=123, - num_frames_from_header=None, # None to test calculating num_frames - 
num_frames_from_content=None, # None to test calculating num_frames - begin_stream_seconds_from_header=0, - begin_stream_seconds_from_content=0, - end_stream_seconds_from_content=4, - codec="whatever", - width=123, - height=321, - pixel_aspect_ratio=Fraction(10, 11), - average_fps_from_header=average_fps_from_header, - stream_index=0, - ) - - assert metadata.num_frames == expected_num_frames +def test_rotation_metadata(): + """Test that rotation metadata is correctly extracted for rotated video.""" + # NASA_VIDEO_ROTATED has 90-degree rotation metadata + decoder_rotated = VideoDecoder(NASA_VIDEO_ROTATED.path) + assert decoder_rotated.metadata.rotation is not None + assert decoder_rotated.metadata.rotation == 90 + # NASA_VIDEO has no rotation metadata + decoder = VideoDecoder(NASA_VIDEO.path) + assert decoder.metadata.rotation is None -@pytest.mark.parametrize( - "duration_seconds_from_header, begin_stream_seconds_from_content, end_stream_seconds_from_content, expected_duration_seconds", - [(60, 5, 20, 15), (60, 1, None, 60), (60, None, 1, 60), (None, 0, 10, 10)], -) -def test_duration_seconds_fallback( - duration_seconds_from_header, - begin_stream_seconds_from_content, - end_stream_seconds_from_content, - expected_duration_seconds, -): - """Check that using begin_stream_seconds_from_content and end_stream_seconds_from_content to calculate `.duration_seconds` - has priority. If either value is missing, duration_seconds_from_header is used. 
- """ - metadata = VideoStreamMetadata( - duration_seconds_from_header=duration_seconds_from_header, - bit_rate=123, - num_frames_from_header=5, - num_frames_from_content=10, - begin_stream_seconds_from_header=0, - begin_stream_seconds_from_content=begin_stream_seconds_from_content, - end_stream_seconds_from_content=end_stream_seconds_from_content, - codec="whatever", - width=123, - height=321, - pixel_aspect_ratio=Fraction(10, 11), - average_fps_from_header=5, - stream_index=0, + # Check that height and width are reported post-rotation + # For 90-degree rotation, width and height should be swapped + assert (decoder_rotated.metadata.height, decoder_rotated.metadata.width) == ( + decoder.metadata.width, + decoder.metadata.height, ) - assert metadata.duration_seconds == expected_duration_seconds +def test_color_metadata(): + # BT2020_LIMITED_RANGE_10BIT is a BT.2020 10-bit HEVC video with PQ transfer + decoder_bt2020 = VideoDecoder(BT2020_LIMITED_RANGE_10BIT.path) + assert decoder_bt2020.metadata.color_primaries == "bt2020" + assert decoder_bt2020.metadata.color_space == "bt2020nc" + assert decoder_bt2020.metadata.color_transfer_characteristic == "smpte2084" + assert decoder_bt2020.metadata.pixel_format == "yuv420p10le" -@pytest.mark.parametrize( - "num_frames_from_header, average_fps_from_header, expected_duration_seconds", - [(100, 10, 10), (100, None, None), (None, 10, None), (None, None, None)], -) -def test_calculate_duration_seconds_using_fps_and_num_frames( - num_frames_from_header, average_fps_from_header, expected_duration_seconds -): - """Check that duration_seconds is calculated using average_fps_from_header and num_frames_from_header - if duration_seconds_from_header is missing. 
- """ - metadata = VideoStreamMetadata( - duration_seconds_from_header=None, # None to test calculating duration_seconds - bit_rate=123, - num_frames_from_header=num_frames_from_header, - num_frames_from_content=10, - begin_stream_seconds_from_header=0, - begin_stream_seconds_from_content=None, # None to test calculating duration_seconds - end_stream_seconds_from_content=None, # None to test calculating duration_seconds - codec="whatever", - width=123, - height=321, - pixel_aspect_ratio=Fraction(10, 11), - average_fps_from_header=average_fps_from_header, - stream_index=0, - ) - assert metadata.duration_seconds_from_header is None - assert metadata.duration_seconds == expected_duration_seconds + # NASA_VIDEO has BT.709 color metadata + decoder_nasa = VideoDecoder(NASA_VIDEO.path) + assert decoder_nasa.metadata.color_primaries == "bt709" + assert decoder_nasa.metadata.color_space == "bt709" + assert decoder_nasa.metadata.color_transfer_characteristic == "bt709" + assert decoder_nasa.metadata.pixel_format == "yuv420p" def test_repr(): @@ -271,26 +195,30 @@ def test_repr(): str(VideoDecoder(NASA_VIDEO.path).metadata) == """VideoStreamMetadata: duration_seconds_from_header: 13.013 - begin_stream_seconds_from_header: 0.0 - bit_rate: 128783.0 + begin_stream_seconds_from_header: 0 + bit_rate: 128783 codec: h264 stream_index: 3 - begin_stream_seconds_from_content: 0.0 + duration_seconds: 13.013 + begin_stream_seconds: 0 + begin_stream_seconds_from_content: 0 end_stream_seconds_from_content: 13.013 width: 480 height: 270 num_frames_from_header: 390 num_frames_from_content: 390 - average_fps_from_header: 29.97003 + average_fps_from_header: 29.97002997002997 pixel_aspect_ratio: 1 - duration_seconds: 13.013 - begin_stream_seconds: 0.0 + rotation: None + color_primaries: bt709 + color_space: bt709 + color_transfer_characteristic: bt709 + pixel_format: yuv420p end_stream_seconds: 13.013 num_frames: 390 average_fps: 29.97002997002997 """ ) - ffmpeg_major_version = 
get_ffmpeg_major_version() expected_duration_seconds_from_header = ( 13.056 if ffmpeg_major_version >= 8 else 13.248 ) @@ -300,9 +228,11 @@ def test_repr(): == f"""AudioStreamMetadata: duration_seconds_from_header: {expected_duration_seconds_from_header} begin_stream_seconds_from_header: 0.138125 - bit_rate: 64000.0 + bit_rate: 64000 codec: mp3 stream_index: 0 + duration_seconds: {expected_duration_seconds_from_header} + begin_stream_seconds: 0.138125 sample_rate: 8000 num_channels: 2 sample_format: fltp diff --git a/test/test_ops.py b/test/test_ops.py index e798a7a2b..80dc4b41b 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -10,7 +10,6 @@ os.environ["TORCH_LOGS"] = "output_code" import json -import subprocess import numpy as np import pytest @@ -18,18 +17,9 @@ import torch from torchcodec._core import ( - _add_video_stream, _test_frame_pts_equality, - add_audio_stream, - add_video_stream, - create_from_bytes, - create_from_file, - create_from_file_like, - create_from_tensor, + create_streaming_encoder_to_file, encode_audio_to_file, - encode_video_to_file, - encode_video_to_file_like, - encode_video_to_tensor, get_ffmpeg_library_versions, get_frame_at_index, get_frame_at_pts, @@ -40,26 +30,36 @@ get_frames_in_range, get_json_metadata, get_next_frame, + streaming_encoder_add_frames, + streaming_encoder_add_video_stream, + streaming_encoder_close, +) +from torchcodec._core.ops import ( + _add_video_stream, + add_audio_stream, + add_video_stream, + create_from_bytes, + create_from_file, + create_from_file_like, + create_from_tensor, seek_to_pts, ) + from torchcodec.decoders import VideoDecoder from .utils import ( all_supported_devices, assert_frames_equal, assert_tensor_close_on_at_least, - get_ffmpeg_major_version, - in_fbcode, - IS_WINDOWS, + get_python_version, NASA_AUDIO, NASA_AUDIO_MP3, NASA_VIDEO, needs_cuda, - psnr, + needs_ffmpeg_cli, SINE_MONO_S32, SINE_MONO_S32_44100, SINE_MONO_S32_8000, - TEST_SRC_2_720P, unsplit_device_str, ) @@ -378,6 +378,10 @@ 
def test_throws_exception_if_seek_too_far(self, device): with pytest.raises(IndexError, match="no more frames"): get_next_frame(decoder) + @pytest.mark.skipif( + get_python_version() >= (3, 14), + reason="torch.compile is not supported on Python 3.14+", + ) @pytest.mark.parametrize("device", all_supported_devices()) def test_compile_seek_and_next(self, device): # TODO_OPEN_ISSUE Scott (T180277797): Get this to work with the inductor stack. Right now @@ -505,10 +509,7 @@ def test_frame_pts_equality(self): ) assert pts_is_equal - @pytest.mark.skipif( - in_fbcode(), - reason="ffprobe not available internally", - ) + @needs_ffmpeg_cli def test_seek_mode_custom_frame_mappings_fails(self): with pytest.raises( RuntimeError, @@ -549,10 +550,7 @@ def test_seek_mode_custom_frame_mappings_fails(self): decoder, stream_index=0, custom_frame_mappings=different_lengths ) - @pytest.mark.skipif( - in_fbcode(), - reason="ffprobe not available internally", - ) + @needs_ffmpeg_cli @pytest.mark.parametrize("device", all_supported_devices()) def test_seek_mode_custom_frame_mappings(self, device): stream_index = 3 # custom_frame_index seek mode requires a stream index @@ -1016,7 +1014,7 @@ def seek(self, offset: int, whence: int) -> int: class SeekMethodMissing: def read(self, size: int) -> bytes: - return bytes() + return b"" with pytest.raises(RuntimeError, match="must implement a seek method"): create_from_file_like(SeekMethodMissing(), "approximate") @@ -1027,7 +1025,7 @@ def __init__(self, file: io.RawIOBase): # io.RawIOBase says we should accept a single int; wrong signature on purpose def read(self) -> bytes: - return bytes() + return b"" def seek(self, offset: int, whence: int) -> int: return self._file.seeK(offset, whence) @@ -1128,7 +1126,7 @@ def test_bad_input(self, tmp_path): valid_output_file = str(tmp_path / ".mp3") - with pytest.raises(RuntimeError, match="must have float32 dtype, got int"): + with pytest.raises(RuntimeError, match="must have float32 dtype, got Int"): 
encode_audio_to_file( samples=torch.arange(10, dtype=torch.int), sample_rate=10, @@ -1151,273 +1149,114 @@ def test_bad_input(self, tmp_path): ) -class TestVideoEncoderOps: - def decode(self, source=None) -> torch.Tensor: - return VideoDecoder(source).get_frames_in_range(start=0, stop=60) - - @pytest.mark.parametrize( - "format", ("mov", "mp4", "mkv", pytest.param("webm", marks=pytest.mark.slow)) - ) - @pytest.mark.parametrize("method", ("to_file", "to_tensor", "to_file_like")) - def test_video_encoder_round_trip(self, tmp_path, format, method): - # Test that decode(encode(decode(frames))) == decode(frames) - ffmpeg_version = get_ffmpeg_major_version() - # In FFmpeg6, the default codec's best pixel format is lossy for all container formats but webm. - # As a result, we skip the round trip test. - if ffmpeg_version == 6 and format != "webm": - pytest.skip( - f"FFmpeg6 defaults to lossy encoding for {format}, skipping round-trip test." - ) - if format == "webm" and ( - ffmpeg_version == 4 or (IS_WINDOWS and ffmpeg_version in (6, 7)) - ): - pytest.skip("Codec for webm is not available in this FFmpeg installation.") - source_frames = self.decode(TEST_SRC_2_720P.path).data - - params = dict( - frame_rate=30, crf=0 - ) # Frame rate is fixed with num frames decoded - if method == "to_file": - encoded_path = str(tmp_path / f"encoder_output.{format}") - encode_video_to_file( - frames=source_frames, - filename=encoded_path, - **params, - ) - round_trip_frames = self.decode(encoded_path).data - elif method == "to_tensor": - encoded_tensor = encode_video_to_tensor( - source_frames, format=format, **params - ) - round_trip_frames = self.decode(encoded_tensor).data - elif method == "to_file_like": - file_like = io.BytesIO() - encode_video_to_file_like( - frames=source_frames, - format=format, - file_like=file_like, - **params, - ) - round_trip_frames = self.decode(file_like.getvalue()).data - else: - raise ValueError(f"Unknown method: {method}") - - assert source_frames.shape == 
round_trip_frames.shape - assert source_frames.dtype == round_trip_frames.dtype - - # If FFmpeg selects a codec or pixel format that does lossy encoding, assert 99% of pixels - # are within a higher tolerance. - if ffmpeg_version == 6: - assert_close = partial(assert_tensor_close_on_at_least, percentage=99) - atol = 15 - else: - assert_close = torch.testing.assert_close - atol = 2 - for s_frame, rt_frame in zip(source_frames, round_trip_frames): - assert psnr(s_frame, rt_frame) > 30 - assert_close(s_frame, rt_frame, atol=atol, rtol=0) - - @pytest.mark.parametrize( - "format", - ( - "mov", - "mp4", - "avi", - "mkv", - "flv", - "gif", - pytest.param("webm", marks=pytest.mark.slow), - ), - ) - @pytest.mark.parametrize("method", ("to_tensor", "to_file_like")) - def test_against_to_file(self, tmp_path, format, method): - # Test that to_file, to_tensor, and to_file_like produce the same results - ffmpeg_version = get_ffmpeg_major_version() - if format == "webm" and ( - ffmpeg_version == 4 or (IS_WINDOWS and ffmpeg_version in (6, 7)) - ): - pytest.skip("Codec for webm is not available in this FFmpeg installation.") - - source_frames = self.decode(TEST_SRC_2_720P.path).data - params = dict(frame_rate=30, crf=0) - - encoded_file = tmp_path / f"output.{format}" - encode_video_to_file(frames=source_frames, filename=str(encoded_file), **params) - - if method == "to_tensor": - encoded_output = encode_video_to_tensor( - source_frames, format=format, **params - ) - else: # to_file_like - file_like = io.BytesIO() - encode_video_to_file_like( - frames=source_frames, - file_like=file_like, - format=format, - **params, - ) - encoded_output = file_like.getvalue() - - torch.testing.assert_close( - self.decode(encoded_file).data, - self.decode(encoded_output).data, - atol=0, - rtol=0, - ) +class TestMultiStreamEncoderOps: + def test_double_close(self, tmp_path): + encoder_tensor = create_streaming_encoder_to_file(str(tmp_path / "test.mp4")) + streaming_encoder_close(encoder_tensor) + 
streaming_encoder_close(encoder_tensor) # double close is a no-op - @pytest.mark.skipif(in_fbcode(), reason="ffmpeg CLI not available") - @pytest.mark.parametrize( - "format", - ( - "mov", - "mp4", - "avi", - "mkv", - "flv", - "gif", - pytest.param("webm", marks=pytest.mark.slow), - ), - ) - def test_video_encoder_against_ffmpeg_cli(self, tmp_path, format): - ffmpeg_version = get_ffmpeg_major_version() - if format == "webm" and ( - ffmpeg_version == 4 or (IS_WINDOWS and ffmpeg_version in (6, 7)) - ): - pytest.skip("Codec for webm is not available in this FFmpeg installation.") - - source_frames = self.decode(TEST_SRC_2_720P.path).data - - # Encode with FFmpeg CLI - temp_raw_path = str(tmp_path / "temp_input.raw") - with open(temp_raw_path, "wb") as f: - f.write(source_frames.permute(0, 2, 3, 1).cpu().numpy().tobytes()) - - ffmpeg_encoded_path = str(tmp_path / f"ffmpeg_output.{format}") - frame_rate = 30 - crf = 0 - # Some codecs (ex. MPEG4) do not support CRF. - # Flags not supported by the selected codec will be ignored. 
- ffmpeg_cmd = [ - "ffmpeg", - "-y", - "-f", - "rawvideo", - "-pix_fmt", - "rgb24", - "-s", - f"{source_frames.shape[3]}x{source_frames.shape[2]}", - "-r", - str(frame_rate), - "-i", - temp_raw_path, - "-crf", - str(crf), - ffmpeg_encoded_path, - ] - subprocess.run(ffmpeg_cmd, check=True) + @pytest.mark.parametrize("format", ["mp4", "mov", "mkv"]) + def test_add_video_stream_and_encode_frames(self, tmp_path, format): + source_decoder = VideoDecoder(str(NASA_VIDEO.path)) + source_frames = source_decoder.get_frames_in_range(start=0, stop=10).data + frame_rate = source_decoder.metadata.average_fps - # Encode with our video encoder - encoder_output_path = str(tmp_path / f"encoder_output.{format}") - encode_video_to_file( - frames=source_frames, + output_file = str(tmp_path / f"test.{format}") + encoder = create_streaming_encoder_to_file(output_file) + streaming_encoder_add_video_stream( + encoder, frame_rate=frame_rate, - filename=encoder_output_path, - crf=crf, + pixel_format="yuv444p", + crf=0, ) + streaming_encoder_add_frames(encoder, source_frames[:5]) + streaming_encoder_add_frames(encoder, source_frames[5:]) + streaming_encoder_close(encoder) - ffmpeg_frames = self.decode(ffmpeg_encoded_path).data - encoder_frames = self.decode(encoder_output_path).data - - assert ffmpeg_frames.shape[0] == encoder_frames.shape[0] - - # If FFmpeg selects a codec or pixel format that uses qscale (not crf), - # the VideoEncoder outputs *slightly* different frames. - # There may be additional subtle differences in the encoder. 
- percentage = 94 if ffmpeg_version == 6 or format == "avi" else 99 - - # Check that PSNR between both encoded versions is high - for ff_frame, enc_frame in zip(ffmpeg_frames, encoder_frames): - res = psnr(ff_frame, enc_frame) - assert res > 30 - assert_tensor_close_on_at_least( - ff_frame, enc_frame, percentage=percentage, atol=2 - ) - - def test_to_file_like_custom_file_object(self): - """Test to_file_like with a custom file-like object that implements write and seek.""" - - class CustomFileObject: - def __init__(self): - self._file = io.BytesIO() - - def write(self, data): - return self._file.write(data) - - def seek(self, offset, whence=0): - return self._file.seek(offset, whence) - - def get_encoded_data(self): - return self._file.getvalue() - - source_frames = self.decode(TEST_SRC_2_720P.path).data - file_like = CustomFileObject() - encode_video_to_file_like( - source_frames, frame_rate=30, crf=0, format="mp4", file_like=file_like + decoded_frames = ( + VideoDecoder(output_file).get_frames_in_range(start=0, stop=10).data ) - decoded_samples = self.decode(file_like.get_encoded_data()) - - torch.testing.assert_close( - decoded_samples.data, - source_frames, - atol=2, - rtol=0, + assert_tensor_close_on_at_least( + decoded_frames, source_frames, percentage=99, atol=2 ) - def test_to_file_like_real_file(self, tmp_path): - """Test to_file_like with a real file opened in binary write mode.""" - source_frames = self.decode(TEST_SRC_2_720P.path).data - file_path = tmp_path / "test_file_like.mp4" + def test_create_invalid_path(self): + with pytest.raises(RuntimeError, match="make sure it's a valid path"): + create_streaming_encoder_to_file("/nonexistent/dir/test.mp4") - with open(file_path, "wb") as file_like: - encode_video_to_file_like( - source_frames, frame_rate=30, crf=0, format="mp4", file_like=file_like + def test_create_invalid_format(self, tmp_path): + with pytest.raises(RuntimeError, match="check the desired extension"): + 
create_streaming_encoder_to_file(str(tmp_path / "test.bad_extension")) + + @pytest.mark.parametrize("format", ["mp4", "mov"]) + def test_fragmented_mp4(self, format, tmp_path): + source_decoder = VideoDecoder(str(NASA_VIDEO.path)) + source_frames = source_decoder.get_frames_in_range(start=0, stop=10).data + frame_rate = source_decoder.metadata.average_fps + + output_file = str(tmp_path / f"test.{format}") + encoder = create_streaming_encoder_to_file(output_file) + streaming_encoder_add_video_stream( + encoder, + frame_rate=frame_rate, + pixel_format="yuv444p", + crf=0, + # In addition to the fragmentation flag, I found "flush_packets" and "threads" to be necessary to decode frames before close(). + # See other frag flags: https://ffmpeg.org/ffmpeg-formats.html#Fragmentation + # TODO MultiStreamEncoder: Get a better understanding of which options are necessary for reading fragmented mp4s + extra_options=[ + "movflags", + "+frag_every_frame+empty_moov", + "tune", + "zerolatency", + "flush_packets", + "1", + "threads", + "1", + ], + ) + # Here, we decode the available fragmented mp4 frames before calling close() + for batch in [source_frames[:5], source_frames[5:]]: + streaming_encoder_add_frames(encoder, batch) + mid_decoder = VideoDecoder(output_file) + num_available = len(mid_decoder) + assert num_available > 0 + assert_tensor_close_on_at_least( + mid_decoder.get_frames_in_range(start=0, stop=num_available).data, + source_frames[:num_available], + percentage=99, + atol=2, ) - decoded_samples = self.decode(str(file_path)) - torch.testing.assert_close( - decoded_samples.data, + streaming_encoder_close(encoder) + # After close, all frames must be decodable + assert_tensor_close_on_at_least( + VideoDecoder(output_file).get_frames_in_range(start=0, stop=10).data, source_frames, + percentage=99, atol=2, - rtol=0, ) - def test_to_file_like_bad_methods(self): - source_frames = self.decode(TEST_SRC_2_720P.path).data - - class NoWriteMethod: - def seek(self, offset, 
whence=0): - return 0 - - with pytest.raises( - RuntimeError, match="File like object must implement a write method" - ): - encode_video_to_file_like( - source_frames, - frame_rate=30, - format="mp4", - file_like=NoWriteMethod(), - ) - - class NoSeekMethod: - def write(self, data): - return len(data) - - with pytest.raises( - RuntimeError, match="File like object must implement a seek method" - ): - encode_video_to_file_like( - source_frames, frame_rate=30, format="mp4", file_like=NoSeekMethod() - ) + def test_add_video_stream_twice_errors(self, tmp_path): + encoder = create_streaming_encoder_to_file(str(tmp_path / "test.mp4")) + streaming_encoder_add_video_stream(encoder, frame_rate=30.0) + with pytest.raises(RuntimeError, match="already been added"): + streaming_encoder_add_video_stream(encoder, frame_rate=24.0) + + def test_add_frames_different_sizes_errors(self, tmp_path): + encoder = create_streaming_encoder_to_file(str(tmp_path / "test.mp4")) + streaming_encoder_add_video_stream(encoder, frame_rate=30.0) + frames_64 = torch.randint(0, 256, (2, 3, 64, 64), dtype=torch.uint8) + frames_128 = torch.randint(0, 256, (2, 3, 128, 128), dtype=torch.uint8) + streaming_encoder_add_frames(encoder, frames_64) + with pytest.raises(RuntimeError, match="same dimensions"): + streaming_encoder_add_frames(encoder, frames_128) + + def test_add_frames_without_stream_errors(self, tmp_path): + encoder = create_streaming_encoder_to_file(str(tmp_path / "test.mp4")) + frames = torch.randint(0, 256, (5, 3, 64, 64), dtype=torch.uint8) + with pytest.raises(RuntimeError, match="No video stream"): + streaming_encoder_add_frames(encoder, frames) if __name__ == "__main__": diff --git a/test/test_samplers.py b/test/test_samplers.py index 10c529062..a8d281252 100644 --- a/test/test_samplers.py +++ b/test/test_samplers.py @@ -558,12 +558,12 @@ def test_time_based_sampler_errors(sampler): decoder = VideoDecoder(NASA_VIDEO.path) with pytest.raises( - ValueError, 
match=re.escape("sampling_range_start (-1) must be at least 0.0") + ValueError, match=re.escape("sampling_range_start (-1) must be at least 0") ): sampler(decoder, sampling_range_start=-1) with pytest.raises( - ValueError, match=re.escape("sampling_range_end (-1) must be at least 0.0") + ValueError, match=re.escape("sampling_range_end (-1) must be at least 0") ): sampler(decoder, sampling_range_end=-1) @@ -595,6 +595,7 @@ def restore_metadata(): decoder.metadata.num_frames_from_header = ( None # Set to none to prevent fallback calculation ) + decoder.metadata.end_stream_seconds = None with pytest.raises( ValueError, match="Could not infer stream end from video metadata" ): @@ -603,6 +604,7 @@ def restore_metadata(): with restore_metadata(): decoder.metadata.end_stream_seconds_from_content = None decoder.metadata.average_fps_from_header = None + decoder.metadata.average_fps = None with pytest.raises(ValueError, match="Could not infer average fps"): sampler(decoder) diff --git a/test/test_transform_ops.py b/test/test_transform_ops.py index 8d1ba5e53..369d55959 100644 --- a/test/test_transform_ops.py +++ b/test/test_transform_ops.py @@ -13,74 +13,422 @@ import pytest import torch +import torchcodec -from torchcodec._core import ( - _add_video_stream, - add_video_stream, - create_from_file, - get_frame_at_index, - get_json_metadata, - get_next_frame, -) +from torchcodec._core import get_frame_at_index, get_json_metadata +from torchcodec._core.ops import _add_video_stream, add_video_stream, create_from_file +from torchcodec.decoders import VideoDecoder from torchvision.transforms import v2 -from .utils import assert_frames_equal, NASA_VIDEO, needs_cuda +from .utils import ( + assert_frames_equal, + assert_tensor_close_on_at_least, + AV1_VIDEO, + get_ffmpeg_minor_version, + H265_VIDEO, + NASA_VIDEO, + needs_cuda, + TEST_NON_ZERO_START as NON_32_ALIGNED_WIDTH_VIDEO, + TEST_SRC_2_720P, +) + + +class TestPublicVideoDecoderTransformOps: + @pytest.mark.parametrize( + 
"height_scaling_factor, width_scaling_factor", + ((1.5, 1.31), (0.5, 0.71), (0.7, 1.31), (1.5, 0.71), (1.0, 1.0), (2.0, 2.0)), + ) + @pytest.mark.parametrize("video", [NASA_VIDEO, TEST_SRC_2_720P]) + def test_resize_torchvision( + self, video, height_scaling_factor, width_scaling_factor + ): + height = int(video.get_height() * height_scaling_factor) + width = int(video.get_width() * width_scaling_factor) + + # We're using both the TorchCodec object and the TorchVision object to + # ensure that they specify exactly the same thing. + decoder_resize = VideoDecoder( + video.path, transforms=[torchcodec.transforms.Resize(size=(height, width))] + ) + decoder_resize_tv = VideoDecoder( + video.path, transforms=[v2.Resize(size=(height, width))] + ) + + decoder_full = VideoDecoder(video.path) + + num_frames = len(decoder_resize) + assert num_frames == len(decoder_full) + + for frame_index in [ + 0, + int(num_frames * 0.1), + int(num_frames * 0.2), + int(num_frames * 0.3), + int(num_frames * 0.4), + int(num_frames * 0.5), + int(num_frames * 0.75), + int(num_frames * 0.90), + num_frames - 1, + ]: + frame_resize_tv = decoder_resize_tv[frame_index] + frame_resize = decoder_resize[frame_index] + assert_frames_equal(frame_resize_tv, frame_resize) + + frame_full = decoder_full[frame_index] + + frame_tv = v2.functional.resize(frame_full, size=(height, width)) + frame_tv_no_antialias = v2.functional.resize( + frame_full, size=(height, width), antialias=False + ) + + expected_shape = (video.get_num_color_channels(), height, width) + assert frame_resize.shape == expected_shape + assert frame_tv.shape == expected_shape + assert frame_tv_no_antialias.shape == expected_shape -torch._dynamo.config.capture_dynamic_output_shape_ops = True + assert_tensor_close_on_at_least( + frame_resize, frame_tv, percentage=99.8, atol=1 + ) + torch.testing.assert_close(frame_resize, frame_tv, rtol=0, atol=6) + + if height_scaling_factor < 1 or width_scaling_factor < 1: + # Antialias only relevant when 
down-scaling! + with pytest.raises(AssertionError, match="Expected at least"): + assert_tensor_close_on_at_least( + frame_resize, frame_tv_no_antialias, percentage=99, atol=1 + ) + with pytest.raises(AssertionError, match="Tensor-likes are not close"): + torch.testing.assert_close( + frame_resize, frame_tv_no_antialias, rtol=0, atol=6 + ) + + def test_resize_fails(self): + with pytest.raises( + ValueError, + match=r"must use bilinear interpolation", + ): + VideoDecoder( + NASA_VIDEO.path, + transforms=[ + v2.Resize( + size=(100, 100), interpolation=v2.InterpolationMode.BICUBIC + ) + ], + ) + with pytest.raises( + ValueError, + match=r"must have antialias enabled", + ): + VideoDecoder( + NASA_VIDEO.path, + transforms=[v2.Resize(size=(100, 100), antialias=False)], + ) + + with pytest.raises( + ValueError, + match=r"must have a size specified", + ): + VideoDecoder( + NASA_VIDEO.path, transforms=[v2.Resize(size=None, max_size=100)] + ) + + with pytest.raises( + ValueError, + match=r"must have a \(height, width\) pair for the size", + ): + VideoDecoder(NASA_VIDEO.path, transforms=[v2.Resize(size=(100))]) + + with pytest.raises( + ValueError, + match=r"must have a \(height, width\) pair for the size", + ): + VideoDecoder( + NASA_VIDEO.path, + transforms=[torchcodec.transforms.Resize(size=(100, 100, 100))], + ) + + def test_resize_non_32_aligned_input_width(self): + assert NON_32_ALIGNED_WIDTH_VIDEO.get_width() % 32 != 0 + decoder = VideoDecoder( + NON_32_ALIGNED_WIDTH_VIDEO.path, + transforms=[torchcodec.transforms.Resize(size=(224, 224))], + ) + assert decoder[0].shape == (3, 224, 224) -class TestVideoDecoderTransformOps: - # We choose arbitrary values for width and height scaling to get better - # test coverage. Some pairs upscale the image while others downscale it. 
@pytest.mark.parametrize( - "width_scaling_factor,height_scaling_factor", - ((1.31, 1.5), (0.71, 0.5), (1.31, 0.7), (0.71, 1.5), (1.0, 1.0)), + "height_scaling_factor, width_scaling_factor", + ((0.5, 0.5), (0.25, 0.1), (1.0, 1.0), (0.15, 0.75)), ) - @pytest.mark.parametrize("input_video", [NASA_VIDEO]) - def test_color_conversion_library_with_scaling( - self, input_video, width_scaling_factor, height_scaling_factor + @pytest.mark.parametrize("video", [NASA_VIDEO, TEST_SRC_2_720P]) + def test_center_crop_torchvision( + self, + height_scaling_factor, + width_scaling_factor, + video, ): - decoder = create_from_file(str(input_video.path)) + height = int(video.get_height() * height_scaling_factor) + width = int(video.get_width() * width_scaling_factor) + + tc_center_crop = torchcodec.transforms.CenterCrop(size=(height, width)) + decoder_center_crop = VideoDecoder(video.path, transforms=[tc_center_crop]) + + decoder_center_crop_tv = VideoDecoder( + video.path, + transforms=[v2.CenterCrop(size=(height, width))], + ) + + decoder_full = VideoDecoder(video.path) + + num_frames = len(decoder_center_crop_tv) + assert num_frames == len(decoder_full) + + for frame_index in [ + 0, + int(num_frames * 0.25), + int(num_frames * 0.5), + int(num_frames * 0.75), + num_frames - 1, + ]: + frame_center_crop = decoder_center_crop[frame_index] + frame_center_crop_tv = decoder_center_crop_tv[frame_index] + assert_frames_equal(frame_center_crop, frame_center_crop_tv) + + expected_shape = (video.get_num_color_channels(), height, width) + assert frame_center_crop_tv.shape == expected_shape + + frame_full = decoder_full[frame_index] + frame_tv = v2.CenterCrop(size=(height, width))(frame_full) + assert_frames_equal(frame_center_crop, frame_tv) + + def test_center_crop_fails(self): + with pytest.raises( + ValueError, + match=r"must have a \(height, width\) pair for the size", + ): + VideoDecoder( + NASA_VIDEO.path, + transforms=[torchcodec.transforms.CenterCrop(size=(100,))], + ) + + 
@pytest.mark.parametrize( + "height_scaling_factor, width_scaling_factor", + ((0.5, 0.5), (0.25, 0.1), (1.0, 1.0), (0.15, 0.75)), + ) + @pytest.mark.parametrize("video", [NASA_VIDEO, TEST_SRC_2_720P]) + @pytest.mark.parametrize("seed", [0, 1234]) + def test_random_crop_torchvision( + self, + height_scaling_factor, + width_scaling_factor, + video, + seed, + ): + height = int(video.get_height() * height_scaling_factor) + width = int(video.get_width() * width_scaling_factor) + + # We want both kinds of RandomCrop objects to get arrive at the same + # locations to crop, so we need to make sure they get the same random + # seed. It's used in RandomCrop's _make_transform_spec() method, called + # by the VideoDecoder. + torch.manual_seed(seed) + tc_random_crop = torchcodec.transforms.RandomCrop(size=(height, width)) + decoder_random_crop = VideoDecoder(video.path, transforms=[tc_random_crop]) + + # Resetting manual seed for when TorchCodec's RandomCrop, created from + # the TorchVision RandomCrop, is used inside of the VideoDecoder. It + # needs to match the call above. + torch.manual_seed(seed) + decoder_random_crop_tv = VideoDecoder( + video.path, + transforms=[v2.RandomCrop(size=(height, width))], + ) + + decoder_full = VideoDecoder(video.path) + + num_frames = len(decoder_random_crop_tv) + assert num_frames == len(decoder_full) + + for frame_index in [ + 0, + int(num_frames * 0.25), + int(num_frames * 0.5), + int(num_frames * 0.75), + num_frames - 1, + ]: + frame_random_crop = decoder_random_crop[frame_index] + frame_random_crop_tv = decoder_random_crop_tv[frame_index] + assert_frames_equal(frame_random_crop, frame_random_crop_tv) + + expected_shape = (video.get_num_color_channels(), height, width) + assert frame_random_crop_tv.shape == expected_shape + + # Resetting manual seed to make sure the invocation of the + # TorchVision RandomCrop matches the two calls above. 
+ torch.manual_seed(seed) + frame_full = decoder_full[frame_index] + frame_tv = v2.RandomCrop(size=(height, width))(frame_full) + assert_frames_equal(frame_random_crop, frame_tv) + + @pytest.mark.parametrize( + "height_scaling_factor, width_scaling_factor", + ((0.25, 0.1), (0.25, 0.25)), + ) + def test_random_crop_nhwc( + self, + height_scaling_factor, + width_scaling_factor, + ): + height = int(TEST_SRC_2_720P.get_height() * height_scaling_factor) + width = int(TEST_SRC_2_720P.get_width() * width_scaling_factor) + + decoder = VideoDecoder( + TEST_SRC_2_720P.path, + transforms=[torchcodec.transforms.RandomCrop(size=(height, width))], + dimension_order="NHWC", + ) + + num_frames = len(decoder) + for frame_index in [ + 0, + int(num_frames * 0.25), + int(num_frames * 0.5), + int(num_frames * 0.75), + num_frames - 1, + ]: + frame = decoder[frame_index] + assert frame.shape == (height, width, 3) + + @pytest.mark.parametrize( + "error_message, params", + ( + ("must not specify padding", dict(size=(100, 100), padding=255)), + ( + "must not specify pad_if_needed", + dict(size=(100, 100), pad_if_needed=True), + ), + ("fill must be 0", dict(size=(100, 100), fill=255)), + ( + "padding_mode must be constant", + dict(size=(100, 100), padding_mode="edge"), + ), + ), + ) + def test_random_crop_fails(self, error_message, params): + with pytest.raises( + ValueError, + match=error_message, + ): + VideoDecoder( + NASA_VIDEO.path, + transforms=[v2.RandomCrop(**params)], + ) + + @pytest.mark.parametrize("seed", [0, 314]) + def test_random_crop_reusable_objects(self, seed): + torch.manual_seed(seed) + random_crop = torchcodec.transforms.RandomCrop(size=(99, 99)) + + # Create a spec which causes us to calculate the random crop location. + first_spec = random_crop._make_transform_spec((888, 888)) + + # Create a spec again, which should calculate a different random crop + # location. 
Despite having the same image size, the specs should be + # different because the crop should be at a different location + second_spec = random_crop._make_transform_spec((888, 888)) + assert first_spec != second_spec + + # Create a spec again, but with a different image size. The specs should + # obviously be different, but the original image size should not be in + # the spec at all. + third_spec = random_crop._make_transform_spec((777, 777)) + assert third_spec != first_spec + assert "888" not in third_spec + + @pytest.mark.parametrize( + "resize, random_crop", + [ + (torchcodec.transforms.Resize, torchcodec.transforms.RandomCrop), + (v2.Resize, v2.RandomCrop), + ], + ) + def test_transform_pipeline(self, resize, random_crop): + decoder = VideoDecoder( + TEST_SRC_2_720P.path, + transforms=[ + # resized to bigger than original + resize(size=(2160, 3840)), + # crop to smaller than the resize, but still bigger than original + random_crop(size=(1080, 1920)), + ], + ) + + num_frames = len(decoder) + for frame_index in [ + 0, + int(num_frames * 0.25), + int(num_frames * 0.5), + int(num_frames * 0.75), + num_frames - 1, + ]: + frame = decoder[frame_index] + assert frame.shape == (TEST_SRC_2_720P.get_num_color_channels(), 1080, 1920) + + def test_transform_fails(self): + with pytest.raises( + ValueError, + match="Unsupported transform", + ): + VideoDecoder(NASA_VIDEO.path, transforms=[v2.RandomHorizontalFlip(p=1.0)]) + + +class TestCoreVideoDecoderTransformOps: + def get_num_frames_core_ops(self, video): + decoder = create_from_file(str(video.path)) add_video_stream(decoder) metadata = get_json_metadata(decoder) metadata_dict = json.loads(metadata) - assert metadata_dict["width"] == input_video.width - assert metadata_dict["height"] == input_video.height + num_frames = metadata_dict["numFramesFromHeader"] + assert num_frames is not None + return num_frames - target_height = int(input_video.height * height_scaling_factor) - target_width = int(input_video.width * 
width_scaling_factor) - if width_scaling_factor != 1.0: - assert target_width != input_video.width - if height_scaling_factor != 1.0: - assert target_height != input_video.height + @pytest.mark.parametrize("video", [NASA_VIDEO, H265_VIDEO, AV1_VIDEO]) + def test_color_conversion_library(self, video): + num_frames = self.get_num_frames_core_ops(video) - filtergraph_decoder = create_from_file(str(input_video.path)) + filtergraph_decoder = create_from_file(str(video.path)) _add_video_stream( filtergraph_decoder, - transform_specs=f"resize, {target_height}, {target_width}", color_conversion_library="filtergraph", ) - filtergraph_frame0, _, _ = get_next_frame(filtergraph_decoder) - swscale_decoder = create_from_file(str(input_video.path)) + swscale_decoder = create_from_file(str(video.path)) _add_video_stream( swscale_decoder, - transform_specs=f"resize, {target_height}, {target_width}", color_conversion_library="swscale", ) - swscale_frame0, _, _ = get_next_frame(swscale_decoder) - assert_frames_equal(filtergraph_frame0, swscale_frame0) - assert filtergraph_frame0.shape == (3, target_height, target_width) - @pytest.mark.parametrize( - "width_scaling_factor,height_scaling_factor", - ((1.31, 1.5), (0.71, 0.5), (1.31, 0.7), (0.71, 1.5), (1.0, 1.0)), - ) + for frame_index in [ + 0, + int(num_frames * 0.25), + int(num_frames * 0.5), + int(num_frames * 0.75), + num_frames - 1, + ]: + filtergraph_frame, *_ = get_frame_at_index( + filtergraph_decoder, frame_index=frame_index + ) + swscale_frame, *_ = get_frame_at_index( + swscale_decoder, frame_index=frame_index + ) + + assert_frames_equal(filtergraph_frame, swscale_frame) + @pytest.mark.parametrize("width", [30, 32, 300]) @pytest.mark.parametrize("height", [128]) def test_color_conversion_library_with_generated_videos( - self, tmp_path, width, height, width_scaling_factor, height_scaling_factor + self, tmp_path, width, height ): # We consider filtergraph to be the reference color conversion library. 
# However the video decoder sometimes uses swscale as that is faster. @@ -129,27 +477,22 @@ def test_color_conversion_library_with_generated_videos( assert metadata_dict["width"] == width assert metadata_dict["height"] == height - target_height = int(height * height_scaling_factor) - target_width = int(width * width_scaling_factor) - if width_scaling_factor != 1.0: - assert target_width != width - if height_scaling_factor != 1.0: - assert target_height != height + num_frames = metadata_dict["numFramesFromHeader"] + assert num_frames is not None and num_frames == 1 filtergraph_decoder = create_from_file(str(video_path)) _add_video_stream( filtergraph_decoder, - transform_specs=f"resize, {target_height}, {target_width}", color_conversion_library="filtergraph", ) - filtergraph_frame0, _, _ = get_next_frame(filtergraph_decoder) auto_decoder = create_from_file(str(video_path)) add_video_stream( auto_decoder, - transform_specs=f"resize, {target_height}, {target_width}", ) - auto_frame0, _, _ = get_next_frame(auto_decoder) + + filtergraph_frame0, *_ = get_frame_at_index(filtergraph_decoder, frame_index=0) + auto_frame0, *_ = get_frame_at_index(auto_decoder, frame_index=0) assert_frames_equal(filtergraph_frame0, auto_frame0) @needs_cuda @@ -175,6 +518,34 @@ def test_transform_fails(self): ): add_video_stream(decoder, transform_specs="invalid, 1, 2") + def test_resize_ffmpeg(self): + height = 135 + width = 240 + expected_shape = (NASA_VIDEO.get_num_color_channels(), height, width) + resize_spec = f"resize, {height}, {width}" + resize_filtergraph = f"scale={width}:{height}:flags=bilinear" + + decoder_resize = create_from_file(str(NASA_VIDEO.path)) + add_video_stream(decoder_resize, transform_specs=resize_spec) + + for frame_index in [17, 230, 389]: + frame_resize, *_ = get_frame_at_index( + decoder_resize, frame_index=frame_index + ) + frame_ref = NASA_VIDEO.get_frame_data_by_index( + frame_index, filters=resize_filtergraph + ) + + assert frame_resize.shape == expected_shape 
+ assert frame_ref.shape == expected_shape + + if torchcodec.ffmpeg_major_version <= 4 and get_ffmpeg_minor_version() <= 1: + # FFmpeg version 4.1 and before appear to have a different + # resize implementation. + torch.testing.assert_close(frame_resize, frame_ref, rtol=0, atol=2) + else: + assert_frames_equal(frame_resize, frame_ref) + def test_resize_transform_fails(self): decoder = create_from_file(str(NASA_VIDEO.path)) with pytest.raises( @@ -224,7 +595,7 @@ def test_crop_transform(self): add_video_stream(decoder_full) for frame_index in [0, 15, 200, 389]: - frame, *_ = get_frame_at_index(decoder_crop, frame_index=frame_index) + frame_crop, *_ = get_frame_at_index(decoder_crop, frame_index=frame_index) frame_ref = NASA_VIDEO.get_frame_data_by_index( frame_index, filters=crop_filtergraph ) @@ -234,12 +605,12 @@ def test_crop_transform(self): frame_full, top=y, left=x, height=height, width=width ) - assert frame.shape == expected_shape + assert frame_crop.shape == expected_shape assert frame_ref.shape == expected_shape assert frame_tv.shape == expected_shape - assert_frames_equal(frame, frame_tv) - assert_frames_equal(frame, frame_ref) + assert_frames_equal(frame_crop, frame_ref) + assert_frames_equal(frame_crop, frame_tv) def test_crop_transform_fails(self): @@ -266,14 +637,14 @@ def test_crop_transform_fails(self): with pytest.raises( RuntimeError, - match="x position out of bounds", + match="x start position, 9999, out of bounds", ): decoder = create_from_file(str(NASA_VIDEO.path)) add_video_stream(decoder, transform_specs="crop, 100, 100, 9999, 100") with pytest.raises( RuntimeError, - match="y position out of bounds", + match=r"Crop output height \(999\) is greater than input height \(270\)", ): decoder = create_from_file(str(NASA_VIDEO.path)) add_video_stream(decoder, transform_specs="crop, 999, 100, 100, 100") diff --git a/test/third-party-interface/CMakeLists.txt b/test/third-party-interface/CMakeLists.txt new file mode 100644 index 000000000..d3d732016 
--- /dev/null +++ b/test/third-party-interface/CMakeLists.txt @@ -0,0 +1,34 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# See test_third_party_interface.py for context + +cmake_minimum_required(VERSION 3.18) +project(ThirdPartyInterfaceTest) +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +find_package(Torch REQUIRED) +find_package(TorchCodec REQUIRED) + +function(make_torchcodec_targets torchcodec_variant) + set(libname "torchcodec_third_party_interface_test${torchcodec_variant}") + set(sources ThirdPartyInterfaceTest.cpp) + + # Building executable to check the linkage + add_library(${libname} SHARED ${sources}) + set_target_properties(${libname} PROPERTIES PREFIX "") + set_target_properties(${libname} PROPERTIES CXX_STANDARD 17) + target_link_libraries(${libname} + ${TORCH_LIBRARIES} + torchcodec::core${torchcodec_variant} + torchcodec::ffmpeg${torchcodec_variant} + ) +endfunction() + +foreach(variant IN LISTS TORCHCODEC_VARIANTS) + make_torchcodec_targets(${variant}) +endforeach() diff --git a/test/third-party-interface/ThirdPartyInterfaceTest.cpp b/test/third-party-interface/ThirdPartyInterfaceTest.cpp new file mode 100644 index 000000000..1623c797f --- /dev/null +++ b/test/third-party-interface/ThirdPartyInterfaceTest.cpp @@ -0,0 +1,41 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +// See test_third_party_interface.py for context. 
+#include "DeviceInterface.h" +#include "FilterGraph.h" + +namespace facebook::torchcodec { + +class DummyDeviceInterface : public DeviceInterface { + public: + DummyDeviceInterface(const StableDevice& device) : DeviceInterface(device) {} + + virtual ~DummyDeviceInterface() {} + + void initialize( + const AVStream* avStream, + const UniqueDecodingAVFormatContext& avFormatCtx, + const SharedAVCodecContext& codecContext) override {} + + void convertAVFrameToFrameOutput( + UniqueAVFrame& avFrame, + FrameOutput& frameOutput, + std::optional preAllocatedOutputTensor = + std::nullopt) override {} + + private: + std::unique_ptr filterGraphContext_; +}; + +namespace { +static bool g_dummy = registerDeviceInterface( + DeviceInterfaceKey(StableDeviceType::PrivateUse1), + [](const StableDevice& device) { + return new DummyDeviceInterface(device); + }); +} // namespace +} // namespace facebook::torchcodec diff --git a/test/third-party-interface/test_third_party_interface.py b/test/third-party-interface/test_third_party_interface.py new file mode 100644 index 000000000..bc6b5ea36 --- /dev/null +++ b/test/third-party-interface/test_third_party_interface.py @@ -0,0 +1,114 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +We allow third-parties to build their own C++ TorchCodec extensions via the DeviceInterface API. +This test ensures that such third-party extensions can be built correctly. +""" + +import os +import subprocess + +from pathlib import Path + +import torch +import torchcodec + + +def test_third_party_interface_pkgconfig(tmp_path): + # Test building of third-party-interface. 
Since + # TORCHCODEC_FFMPEG{ver}_INSTALL_PREFIX is not provided, FFmpeg should be + # found via pkg-config + cmake_args = [ + "cmake", + "-DCMAKE_BUILD_TYPE=Debug", + "-DCMAKE_VERBOSE_MAKEFILE=ON", + f"-DCMAKE_PREFIX_PATH={torchcodec.cmake_prefix_path};{torch.utils.cmake_prefix_path}", + Path(__file__).parent, + ] + result = subprocess.run(cmake_args, cwd=tmp_path) + assert result.returncode == 0 + + result = subprocess.run(["cmake", "--build", "."], cwd=tmp_path) + assert result.returncode == 0 + + # loading built .so in the separate process to avoid flooding current process + ver = f"{torchcodec.ffmpeg_major_version}" + result = subprocess.run( + [ + "python3", + "-c", + f"import torch; torch.ops.load_library('torchcodec_third_party_interface_test{ver}.so')", + ], + cwd=tmp_path, + ) + assert result.returncode == 0 + + +def test_third_party_interface_fails_when_no_ffmpeg(tmp_path): + # Test that passing non-existing TORCHCODEC_FFMPEG{ver}_INSTALL_PREFIX + # makes cmake configuration fail + cmake_args = [ + "cmake", + "-DCMAKE_BUILD_TYPE=Debug", + "-DCMAKE_VERBOSE_MAKEFILE=ON", + f"-DCMAKE_PREFIX_PATH={torchcodec.cmake_prefix_path};{torch.utils.cmake_prefix_path}", + Path(__file__).parent, + ] + ver = f"{torchcodec.ffmpeg_major_version}" + my_env = os.environ.copy() + my_env[f"TORCHCODEC_FFMPEG{ver}_INSTALL_PREFIX"] = ( + Path(__file__).parent / "no-such-dir" + ) + + # cmake config should fail as we've set ffmpeg install prefix to the not existing + # directory + result = subprocess.run(cmake_args, cwd=tmp_path, env=my_env) + assert result.returncode != 0 + + +def test_third_party_interface_with_prefix(tmp_path): + # Test that passing a valid TORCHCODEC_FFMPEG{ver}_INSTALL_PREFIX uses those + # FFmpeg libraries. 
+ cmake_args = [ + "cmake", + "-DCMAKE_BUILD_TYPE=Debug", + "-DCMAKE_VERBOSE_MAKEFILE=ON", + f"-DCMAKE_PREFIX_PATH={torchcodec.cmake_prefix_path};{torch.utils.cmake_prefix_path}", + Path(__file__).parent, + ] + + # In this test we are calculating the prefix of installed ffmpeg version from the location + # of its libavcodec.pc file. Potentially, on the custom ffmpeg install with custom layout + # our calculation can be wrong and test might fail. + result = subprocess.run( + ["pkg-config", "--variable=prefix", "libavcodec"], + capture_output=True, + text=True, + ) + assert result.returncode == 0 + + ver = f"{torchcodec.ffmpeg_major_version}" + my_env = os.environ.copy() + my_env[f"TORCHCODEC_FFMPEG{ver}_INSTALL_PREFIX"] = Path(f"{result.stdout.strip()}") + + result = subprocess.run(cmake_args, cwd=tmp_path, env=my_env) + assert result.returncode == 0 + + result = subprocess.run(["cmake", "--build", "."], cwd=tmp_path) + assert result.returncode == 0 + + # loading built .so in the separate process to avoid flooding current process + ver = f"{torchcodec.ffmpeg_major_version}" + result = subprocess.run( + [ + "python3", + "-c", + f"import torch; torch.ops.load_library('torchcodec_third_party_interface_test{ver}.so')", + ], + cwd=tmp_path, + ) + assert result.returncode == 0 diff --git a/test/utils.py b/test/utils.py index cbd6a5bf4..1d95c3d7c 100644 --- a/test/utils.py +++ b/test/utils.py @@ -2,22 +2,22 @@ import json import os import pathlib +import platform import subprocess import sys - from dataclasses import dataclass, field -from typing import Dict, List, Optional, Union import numpy as np import pytest - import torch +from torchcodec import ffmpeg_major_version from torchcodec._core import get_ffmpeg_library_versions from torchcodec.decoders import set_cuda_backend, VideoDecoder from torchcodec.decoders._video_decoder import _read_custom_frame_mappings IS_WINDOWS = sys.platform in ("win32", "cygwin") +IN_GITHUB_CI = bool(os.getenv("GITHUB_ACTIONS")) # Decorator 
for skipping CUDA tests when CUDA isn't available. The tests are @@ -27,6 +27,13 @@ def needs_cuda(test_item): return pytest.mark.needs_cuda(test_item) +# Decorator for skipping ffmpeg tests when ffmpeg cli isn't available. The tests are +# effectively marked to be skipped in pytest_collection_modifyitems() of +# conftest.py +def needs_ffmpeg_cli(test_item): + return pytest.mark.needs_ffmpeg_cli(test_item) + + # This is a special device string that we use to test the "beta" CUDA backend. # It only exists here, in this test utils file. Public and core APIs have no # idea that this is how we're tesing them. That is, that's not a supported @@ -45,6 +52,13 @@ def all_supported_devices(): ) +def cuda_devices(): + return ( + pytest.param("cuda", marks=pytest.mark.needs_cuda), + pytest.param(_CUDA_BETA_DEVICE_STR, marks=pytest.mark.needs_cuda), + ) + + def unsplit_device_str(device_str: str) -> str: # helper meant to be used as # device, device_variant = unsplit_device_str(device) @@ -76,28 +90,21 @@ def make_video_decoder(*args, **kwargs) -> tuple[VideoDecoder, str]: return dec, clean_device -def _get_ffmpeg_version_string(): +def get_ffmpeg_minor_version(): ffmpeg_version = get_ffmpeg_library_versions()["ffmpeg_version"] # When building FFmpeg from source there can be a `n` prefix in the version # string. This is quite brittle as we're using av_version_info(), which has # no stable format. 
See https://github.com/pytorch/torchcodec/issues/100 if ffmpeg_version.startswith("n"): ffmpeg_version = ffmpeg_version.removeprefix("n") - - return ffmpeg_version - - -def get_ffmpeg_major_version(): - ffmpeg_version = _get_ffmpeg_version_string() - return int(ffmpeg_version.split(".")[0]) + return int(ffmpeg_version.split(".")[1]) -def get_ffmpeg_minor_version(): - ffmpeg_version = _get_ffmpeg_version_string() - return int(ffmpeg_version.split(".")[1]) +def get_python_version() -> tuple[int, int]: + return (sys.version_info.major, sys.version_info.minor) -def cuda_version_used_for_building_torch() -> Optional[tuple[int, int]]: +def cuda_version_used_for_building_torch() -> tuple[int, int | None]: # Return the CUDA version that was used to build PyTorch. That's not always # the same as the CUDA version that is currently installed on the running # machine, which is what we actually want. On the CI though, these are the @@ -128,10 +135,10 @@ def psnr(a, b, max_val=255) -> float: # not guarantee bit-for-bit equality across systems and architectures, so we # also cannot. We currently use Linux on x86_64 as our reference system. 
def assert_frames_equal(*args, **kwargs): - if sys.platform == "linux": + if sys.platform == "linux" and "x86" in platform.machine().lower(): if args[0].device.type == "cuda": atol = 3 if cuda_version_used_for_building_torch() >= (13, 0) else 2 - if get_ffmpeg_major_version() == 4: + if ffmpeg_major_version == 4: assert_tensor_close_on_at_least( args[0], args[1], percentage=95, atol=atol ) @@ -140,6 +147,7 @@ def assert_frames_equal(*args, **kwargs): else: torch.testing.assert_close(*args, **kwargs, atol=0, rtol=0) else: + # Here: Windows, MacOS, and Linux for non-x86 architectures like aarch64 torch.testing.assert_close(*args, **kwargs, atol=3, rtol=0) @@ -219,10 +227,10 @@ class TestContainerFile: filename: str default_stream_index: int - stream_infos: Dict[int, Union[TestVideoStreamInfo, TestAudioStreamInfo]] - frames: Dict[int, Dict[int, TestFrameInfo]] - _custom_frame_mappings_data: Dict[ - int, Optional[tuple[torch.Tensor, torch.Tensor, torch.Tensor]] + stream_infos: dict[int, TestVideoStreamInfo | TestAudioStreamInfo] + frames: dict[int, dict[int, TestFrameInfo]] + _custom_frame_mappings_data: dict[ + int, tuple[torch.Tensor, torch.Tensor, torch.Tensor | None] ] = field(default_factory=dict) def __post_init__(self): @@ -252,7 +260,7 @@ def __post_init__(self): "You need to submit this file, or specify the `frames` field manually." 
) - with open(frames_info_path, "r") as f: + with open(frames_info_path) as f: frames_info = json.loads(f.read()) self.frames[stream_index] = { frame_index: TestFrameInfo( @@ -271,7 +279,7 @@ def to_tensor(self) -> torch.Tensor: return torch.from_numpy(arr) def get_frame_data_by_index( - self, idx: int, *, stream_index: Optional[int] = None + self, idx: int, *, stream_index: int | None = None ) -> torch.Tensor: raise NotImplementedError("Override in child classes") @@ -281,7 +289,7 @@ def get_frame_data_by_range( stop: int, step: int = 1, *, - stream_index: Optional[int] = None, + stream_index: int | None = None, ) -> torch.Tensor: raise NotImplementedError("Override in child classes") @@ -291,7 +299,7 @@ def get_pts_seconds_by_range( stop: int, step: int = 1, *, - stream_index: Optional[int] = None, + stream_index: int | None = None, ) -> torch.Tensor: if stream_index is None: stream_index = self.default_stream_index @@ -307,7 +315,7 @@ def get_duration_seconds_by_range( stop: int, step: int = 1, *, - stream_index: Optional[int] = None, + stream_index: int | None = None, ) -> torch.Tensor: if stream_index is None: stream_index = self.default_stream_index @@ -319,7 +327,7 @@ def get_duration_seconds_by_range( return torch.tensor(all_durations, dtype=torch.float64) def get_frame_info( - self, idx: int, *, stream_index: Optional[int] = None + self, idx: int, *, stream_index: int | None = None ) -> TestFrameInfo: if stream_index is None: stream_index = self.default_stream_index @@ -328,7 +336,7 @@ def get_frame_info( # This function is used to get the frame mappings for the custom_frame_mappings seek mode. 
def get_custom_frame_mappings( - self, stream_index: Optional[int] = None + self, stream_index: int | None = None ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: if stream_index is None: stream_index = self.default_stream_index @@ -372,7 +380,7 @@ class TestVideo(TestContainerFile): """Base class for the *video* streams of a video container""" def get_base_path_by_index( - self, idx: int, *, stream_index: int, filters: Optional[str] = None + self, idx: int, *, stream_index: int, filters: str | None = None ) -> pathlib.Path: stream_and_frame = f"stream{stream_index}.frame{idx:06d}" if filters is not None: @@ -386,8 +394,8 @@ def get_frame_data_by_index( self, idx: int, *, - stream_index: Optional[int] = None, - filters: Optional[str] = None, + stream_index: int | None = None, + filters: str | None = None, ) -> torch.Tensor: if stream_index is None: stream_index = self.default_stream_index @@ -404,7 +412,7 @@ def get_frame_data_by_range( stop: int, step: int = 1, *, - stream_index: Optional[int] = None, + stream_index: int | None = None, ) -> torch.Tensor: tensors = [ self.get_frame_data_by_index(i, stream_index=stream_index) @@ -430,19 +438,19 @@ def empty_chw_tensor(self) -> torch.Tensor: [0, self.num_color_channels, self.height, self.width], dtype=torch.uint8 ) - def get_width(self, *, stream_index: Optional[int]) -> int: + def get_width(self, *, stream_index: int | None = None) -> int: if stream_index is None: stream_index = self.default_stream_index return self.stream_infos[stream_index].width - def get_height(self, *, stream_index: Optional[int] = None) -> int: + def get_height(self, *, stream_index: int | None = None) -> int: if stream_index is None: stream_index = self.default_stream_index return self.stream_infos[stream_index].height - def get_num_color_channels(self, *, stream_index: Optional[int] = None) -> int: + def get_num_color_channels(self, *, stream_index: int | None = None) -> int: if stream_index is None: stream_index = 
self.default_stream_index @@ -470,6 +478,18 @@ def get_empty_chw_tensor(self, *, stream_index: int) -> torch.Tensor: frames={}, # Automatically loaded from json file ) +NASA_VIDEO_ROTATED = TestVideo( + filename="nasa_13013_rotated.mp4", + default_stream_index=0, + stream_infos={ + # Post-rotation dimensions: 90-degree rotation swaps width/height + # This is a short video (~15 frames) extracted from nasa_13013.mp4 stream 3 + # with 90-degree rotation metadata added + 0: TestVideoStreamInfo(width=270, height=480, num_color_channels=3), + }, + frames={}, # Automatically loaded from json file +) + # Video generated with: # ffmpeg -f lavfi -i testsrc2=duration=1:size=200x200:rate=30 -c:v libx265 -pix_fmt yuv420p10le -preset fast -crf 23 h265_10bits.mp4 H265_10BITS = TestVideo( @@ -544,6 +564,68 @@ def get_empty_chw_tensor(self, *, stream_index: int) -> torch.Tensor: frames={0: {}}, # Not needed for now ) +# BT.2020 10-bit video with limited range (tv), generated with: +# ffmpeg -f lavfi -i testsrc2=duration=2:size=320x240:rate=30 -c:v libx265 \ +# -pix_fmt yuv420p10le -color_primaries bt2020 -color_trc smpte2084 \ +# -colorspace bt2020nc -color_range tv bt2020_10bit.mp4 +# +# Confirm color space with: +# ffprobe -v quiet -select_streams v:0 -show_entries stream=color_space,color_transfer,color_primaries,color_range -of default=noprint_wrappers=1 test/resources/bt2020_10bit.mp4 +# color_range=tv +# color_space=bt2020nc +# color_transfer=smpte2084 +# color_primaries=bt2020 +BT2020_LIMITED_RANGE_10BIT = TestVideo( + filename="bt2020_10bit.mp4", + default_stream_index=0, + stream_infos={ + 0: TestVideoStreamInfo(width=320, height=240, num_color_channels=3), + }, + frames={0: {}}, # Not needed for now +) + +# Full range BT.601 video, generated with: +# ffmpeg -f lavfi -i testsrc2=duration=2:size=320x240:rate=30 -c:v libx264 +# -profile:v high -pix_fmt yuv420p +# -vf "setparams=color_primaries=smpte170m:color_trc=smpte170m:colorspace=smpte170m:range=pc" +# 
bt601_full_range.mp4 +# +# Confirm color space with: +# ffprobe -v quiet -select_streams v:0 -show_entries stream=color_space,color_transfer,color_primaries,color_range -of default=noprint_wrappers=1 test/resources/bt601_full_range.mp4 +# color_range=pc +# color_space=smpte170m +# color_transfer=smpte170m +# color_primaries=smpte170m +BT601_FULL_RANGE = TestVideo( + filename="bt601_full_range.mp4", + default_stream_index=0, + stream_infos={ + 0: TestVideoStreamInfo(width=320, height=240, num_color_channels=3), + }, + frames={0: {}}, # Not needed for now +) + +# Limited range BT.601 video, generated with: +# ffmpeg -f lavfi -i testsrc2=duration=2:size=320x240:rate=30 -c:v libx264 +# -profile:v baseline -pix_fmt yuv420p +# -vf "setparams=color_primaries=smpte170m:color_trc=smpte170m:colorspace=smpte170m:range=tv" +# bt601_limited_range.mp4 +# +# Confirm color space with: +# ffprobe -v quiet -select_streams v:0 -show_entries stream=color_space,color_transfer,color_primaries,color_range -of default=noprint_wrappers=1 test/resources/bt601_limited_range.mp4 +# color_range=tv +# color_space=smpte170m +# color_transfer=smpte170m +# color_primaries=smpte170m +BT601_LIMITED_RANGE = TestVideo( + filename="bt601_limited_range.mp4", + default_stream_index=0, + stream_infos={ + 0: TestVideoStreamInfo(width=320, height=240, num_color_channels=3), + }, + frames={0: {}}, # Not needed for now +) + # ffmpeg -f lavfi -i testsrc2=duration=2:size=1280x720:rate=30 -c:v libx264 -profile:v baseline -level 3.1 -pix_fmt yuv420p -b:v 2500k -r 30 -movflags +faststart output_720p_2s.mp4 TEST_SRC_2_720P = TestVideo( filename="testsrc2.mp4", @@ -593,6 +675,18 @@ def get_empty_chw_tensor(self, *, stream_index: int) -> torch.Tensor: frames={0: {}}, # Not needed for now ) +# Video with non-zero start time (start_time ~8.333s) +# Used to test that PTS values are correctly reported for videos that don't +# start at time 0. 
+TEST_NON_ZERO_START = TestVideo( + filename="test_non_zero_start.mp4", + default_stream_index=0, + stream_infos={ + 0: TestVideoStreamInfo(width=200, height=112, num_color_channels=3), + }, + frames={}, # Automatically loaded from json file +) + def supports_approximate_mode(asset: TestVideo) -> bool: # Those are missing the `duration` field so they fail in approximate mode (on all devices). @@ -606,10 +700,10 @@ class TestAudio(TestContainerFile): """Base class for the *audio* streams of a container (potentially a video), or a pure audio file""" - stream_infos: Dict[int, TestAudioStreamInfo] + stream_infos: dict[int, TestAudioStreamInfo] # stream_index -> list of 2D frame tensors of shape (num_channels, num_samples_in_that_frame) # num_samples_in_that_frame isn't necessarily constant for a given stream. - _reference_frames: Dict[int, List[torch.Tensor]] = field(default_factory=dict) + _reference_frames: dict[int, list[torch.Tensor]] = field(default_factory=dict) # Storing each individual frame is too expensive for audio, because there's # a massive overhead in the binary format saved by pytorch. 
Saving all the @@ -633,7 +727,7 @@ def __post_init__(self): ) def get_frame_data_by_index( - self, idx: int, *, stream_index: Optional[int] = None + self, idx: int, *, stream_index: int | None = None ) -> torch.Tensor: if stream_index is None: stream_index = self.default_stream_index @@ -646,7 +740,7 @@ def get_frame_data_by_range( stop: int, step: int = 1, *, - stream_index: Optional[int] = None, + stream_index: int | None = None, ) -> torch.Tensor: tensors = [ self.get_frame_data_by_index(i, stream_index=stream_index) @@ -655,7 +749,7 @@ def get_frame_data_by_range( return torch.cat(tensors, dim=-1) def get_frame_index( - self, *, pts_seconds: float, stream_index: Optional[int] = None + self, *, pts_seconds: float, stream_index: int | None = None ) -> int: if stream_index is None: stream_index = self.default_stream_index @@ -814,3 +908,20 @@ def sample_format(self) -> str: ) }, ) + +# 16-channel audio for testing support for >8 channels. Generated with: +# ffmpeg -i test/resources/sine_mono_s32.wav -t 1 -filter_complex "[0]asplit=16[s0][s1][s2][s3][s4][s5][s6][s7][s8][s9][s10][s11][s12][s13][s14][s15];[s0][s1][s2][s3][s4][s5][s6][s7][s8][s9][s10][s11][s12][s13][s14][s15]amerge=inputs=16" -c:a pcm_s16le test/resources/sine_16ch_s16.wav +SINE_16_CHANNEL_S16 = TestAudio( + filename="sine_16ch_s16.wav", + default_stream_index=0, + frames={}, # Automatically loaded from json file + stream_infos={ + 0: TestAudioStreamInfo( + sample_rate=16_000, + num_channels=16, + duration_seconds=1, + num_frames=16, + sample_format="s16", + ) + }, +) diff --git a/test_paddle/test_video_decode.py b/test_paddle/test_video_decode.py index 4e41cb19e..1be61bd3e 100644 --- a/test_paddle/test_video_decode.py +++ b/test_paddle/test_video_decode.py @@ -169,8 +169,11 @@ def sample_indices_fn_func(metadata, **fn_kwargs): def test_video_decode(): - url = "https://paddlenlp.bj.bcebos.com/datasets/paddlemix/demo_video/example_video.mp4" - video, metadata = load_video(url, backend="torchcodec") + 
video_path = os.getenv( + "PADDLECODEC_TEST_VIDEO", + "https://paddlenlp.bj.bcebos.com/datasets/paddlemix/demo_video/example_video.mp4", + ) + video, metadata = load_video(video_path, backend="torchcodec") assert video.to(paddle.int64).sum().item() == 247759890390 assert metadata.total_num_frames == 263 assert metadata.fps == pytest.approx(29.99418249715141)