Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit bcde35e

Browse files
committed
Implement initializeFiltersContext for CPU device interface
Signed-off-by: Dmitry Rogozhkin <dmitry.v.rogozhkin@intel.com>
1 parent ebcb48d commit bcde35e

File tree

2 files changed

+108
-92
lines changed

2 files changed

+108
-92
lines changed

‎src/torchcodec/_core/CpuDeviceInterface.cpp‎

Lines changed: 96 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,35 @@ static bool g_cpu = registerDeviceInterface(
1313
torch::kCPU,
1414
[](const torch::Device& device) { return new CpuDeviceInterface(device); });
1515

16+
ColorConversionLibrary getColorConversionLibrary(
17+
const VideoStreamOptions& videoStreamOptions,
18+
int width) {
19+
// By default, we want to use swscale for color conversion because it is
20+
// faster. However, it has width requirements, so we may need to fall back
21+
// to filtergraph. We also need to respect what was requested from the
22+
// options; we respect the options unconditionally, so it's possible for
23+
// swscale's width requirements to be violated. We don't expose the ability to
24+
// choose color conversion library publicly; we only use this ability
25+
// internally.
26+
27+
// swscale requires widths to be multiples of 32:
28+
// https://stackoverflow.com/questions/74351955/turn-off-sw-scale-conversion-to-planar-yuv-32-byte-alignment-requirements
29+
// so we fall back to filtergraph if the width is not a multiple of 32.
30+
auto defaultLibrary = (width %32 ==0)
31+
? ColorConversionLibrary::SWSCALE
32+
: ColorConversionLibrary::FILTERGRAPH;
33+
34+
ColorConversionLibrary colorConversionLibrary =
35+
videoStreamOptions.colorConversionLibrary.value_or(defaultLibrary);
36+
37+
TORCH_CHECK(
38+
colorConversionLibrary == ColorConversionLibrary::SWSCALE ||
39+
colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH,
40+
"Invalid color conversion library:",
41+
static_cast<int>(colorConversionLibrary));
42+
return colorConversionLibrary;
43+
}
44+
1645
}// namespace
1746

1847
CpuDeviceInterface::CpuDeviceInterface(const torch::Device& device)
@@ -22,6 +51,52 @@ CpuDeviceInterface::CpuDeviceInterface(const torch::Device& device)
2251
device_.type() == torch::kCPU,"Unsupported device:", device_.str());
2352
}
2453

54+
std::unique_ptr<FiltersContext> CpuDeviceInterface::initializeFiltersContextInternal(
55+
const VideoStreamOptions& videoStreamOptions,
56+
const UniqueAVFrame& avFrame,
57+
const AVRational& timeBase) {
58+
enum AVPixelFormat frameFormat =
59+
static_cast<enum AVPixelFormat>(avFrame->format);
60+
auto frameDims =
61+
getHeightAndWidthFromOptionsOrAVFrame(videoStreamOptions, avFrame);
62+
int expectedOutputHeight = frameDims.height;
63+
int expectedOutputWidth = frameDims.width;
64+
65+
std::unique_ptr<FiltersContext> filtersContext =
66+
std::make_unique<FiltersContext>();
67+
68+
filtersContext->inputWidth = avFrame->width;
69+
filtersContext->inputHeight = avFrame->height;
70+
filtersContext->inputFormat = frameFormat;
71+
filtersContext->inputAspectRatio = avFrame->sample_aspect_ratio;
72+
filtersContext->outputWidth = expectedOutputWidth;
73+
filtersContext->outputHeight = expectedOutputHeight;
74+
filtersContext->outputFormat = AV_PIX_FMT_RGB24;
75+
filtersContext->timeBase = timeBase;
76+
77+
std::stringstream filters;
78+
filters <<"scale=" << expectedOutputWidth <<":" << expectedOutputHeight;
79+
filters <<":sws_flags=bilinear";
80+
81+
filtersContext->filters = filters.str();
82+
return filtersContext;
83+
}
84+
85+
std::unique_ptr<FiltersContext> CpuDeviceInterface::initializeFiltersContext(
86+
const VideoStreamOptions& videoStreamOptions,
87+
const UniqueAVFrame& avFrame,
88+
const AVRational& timeBase) {
89+
auto frameDims =
90+
getHeightAndWidthFromOptionsOrAVFrame(videoStreamOptions, avFrame);
91+
int expectedOutputWidth = frameDims.width;
92+
93+
if (getColorConversionLibrary(videoStreamOptions, expectedOutputWidth) == ColorConversionLibrary::SWSCALE) {
94+
return nullptr;
95+
}
96+
97+
return initializeFiltersContextInternal(videoStreamOptions, avFrame, timeBase);
98+
}
99+
25100
// Note [preAllocatedOutputTensor with swscale and filtergraph]:
26101
// Callers may pass a pre-allocated tensor, where the output.data tensor will
27102
// be stored. This parameter is honored in any case, but it only leads to a
@@ -56,56 +131,25 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
56131
}
57132

58133
torch::Tensor outputTensor;
59-
// We need to compare the current frame context with our previous frame
60-
// context. If they are different, then we need to re-create our colorspace
61-
// conversion objects. We create our colorspace conversion objects late so
62-
// that we don't have to depend on the unreliable metadata in the header.
63-
// And we sometimes re-create them because it's possible for frame
64-
// resolution to change mid-stream. Finally, we want to reuse the colorspace
65-
// conversion objects as much as possible for performance reasons.
66-
enum AVPixelFormat frameFormat =
67-
static_cast<enum AVPixelFormat>(avFrame->format);
68-
FiltersContext filtersContext;
69-
70-
filtersContext.inputWidth = avFrame->width;
71-
filtersContext.inputHeight = avFrame->height;
72-
filtersContext.inputFormat = frameFormat;
73-
filtersContext.inputAspectRatio = avFrame->sample_aspect_ratio;
74-
filtersContext.outputWidth = expectedOutputWidth;
75-
filtersContext.outputHeight = expectedOutputHeight;
76-
filtersContext.outputFormat = AV_PIX_FMT_RGB24;
77-
filtersContext.timeBase = timeBase;
78-
79-
std::stringstream filters;
80-
filters <<"scale=" << expectedOutputWidth <<":" << expectedOutputHeight;
81-
filters <<":sws_flags=bilinear";
82-
83-
filtersContext.filters = filters.str();
84-
85-
// By default, we want to use swscale for color conversion because it is
86-
// faster. However, it has width requirements, so we may need to fall back
87-
// to filtergraph. We also need to respect what was requested from the
88-
// options; we respect the options unconditionally, so it's possible for
89-
// swscale's width requirements to be violated. We don't expose the ability to
90-
// choose color conversion library publicly; we only use this ability
91-
// internally.
92-
93-
// swscale requires widths to be multiples of 32:
94-
// https://stackoverflow.com/questions/74351955/turn-off-sw-scale-conversion-to-planar-yuv-32-byte-alignment-requirements
95-
// so we fall back to filtergraph if the width is not a multiple of 32.
96-
auto defaultLibrary = (expectedOutputWidth %32 ==0)
97-
? ColorConversionLibrary::SWSCALE
98-
: ColorConversionLibrary::FILTERGRAPH;
99-
100134
ColorConversionLibrary colorConversionLibrary =
101-
videoStreamOptions.colorConversionLibrary.value_or(defaultLibrary);
135+
getColorConversionLibrary(videoStreamOptions, expectedOutputWidth);
102136

103137
if (colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
104138
outputTensor = preAllocatedOutputTensor.value_or(allocateEmptyHWCTensor(
105139
expectedOutputHeight, expectedOutputWidth, torch::kCPU));
106140

141+
// We need to compare the current frame context with our previous frame
142+
// context. If they are different, then we need to re-create our colorspace
143+
// conversion objects. We create our colorspace conversion objects late so
144+
// that we don't have to depend on the unreliable metadata in the header.
145+
// And we sometimes re-create them because it's possible for frame
146+
// resolution to change mid-stream. Finally, we want to reuse the colorspace
147+
// conversion objects as much as possible for performance reasons.
148+
std::unique_ptr<FiltersContext> filtersContext =
149+
initializeFiltersContextInternal(videoStreamOptions, avFrame, timeBase);
150+
107151
if (!swsContext_ || prevFiltersContext_ != filtersContext) {
108-
createSwsContext(filtersContext, avFrame->colorspace);
152+
createSwsContext(*filtersContext, avFrame->colorspace);
109153
prevFiltersContext_ =std::move(filtersContext);
110154
}
111155
int resultHeight =
@@ -122,25 +166,16 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
122166

123167
frameOutput.data = outputTensor;
124168
} else if (colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
125-
if (!filterGraphContext_ || prevFiltersContext_ != filtersContext) {
126-
filterGraphContext_ =
127-
std::make_unique<FilterGraph>(filtersContext, videoStreamOptions);
128-
prevFiltersContext_ =std::move(filtersContext);
129-
}
130-
outputTensor =convertAVFrameToTensorUsingFilterGraph(avFrame);
169+
TORCH_CHECK_EQ(avFrame->format, AV_PIX_FMT_RGB24);
131170

132-
// Similarly to above, if this check fails it means the frame wasn't
133-
// reshaped to its expected dimensions by filtergraph.
134-
auto shape = outputTensor.sizes();
135-
TORCH_CHECK(
136-
(shape.size() ==3) && (shape[0] == expectedOutputHeight) &&
137-
(shape[1] == expectedOutputWidth) && (shape[2] ==3),
138-
"Expected output tensor of shape",
139-
expectedOutputHeight,
140-
"x",
141-
expectedOutputWidth,
142-
"x3, got",
143-
shape);
171+
std::vector<int64_t> shape = {expectedOutputHeight, expectedOutputWidth,3};
172+
std::vector<int64_t> strides = {avFrame->linesize[0],3,1};
173+
AVFrame* avFramePtr = avFrame.release();
174+
auto deleter = [avFramePtr](void*) {
175+
UniqueAVFrameavFrameToDelete(avFramePtr);
176+
};
177+
outputTensor =torch::from_blob(
178+
avFramePtr->data[0], shape, strides, deleter, {torch::kUInt8});
144179

145180
if (preAllocatedOutputTensor.has_value()) {
146181
// We have already validated that preAllocatedOutputTensor and
@@ -150,11 +185,6 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
150185
}else {
151186
frameOutput.data = outputTensor;
152187
}
153-
}else {
154-
TORCH_CHECK(
155-
false,
156-
"Invalid color conversion library:",
157-
static_cast<int>(colorConversionLibrary));
158188
}
159189
}
160190

@@ -176,25 +206,6 @@ int CpuDeviceInterface::convertAVFrameToTensorUsingSwsScale(
176206
return resultHeight;
177207
}
178208

179-
torch::TensorCpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph(
180-
const UniqueAVFrame& avFrame) {
181-
UniqueAVFrame filteredAVFrame = filterGraphContext_->convert(avFrame);
182-
183-
TORCH_CHECK_EQ(filteredAVFrame->format, AV_PIX_FMT_RGB24);
184-
185-
auto frameDims =getHeightAndWidthFromResizedAVFrame(*filteredAVFrame.get());
186-
int height = frameDims.height;
187-
int width = frameDims.width;
188-
std::vector<int64_t> shape = {height, width,3};
189-
std::vector<int64_t> strides = {filteredAVFrame->linesize[0],3,1};
190-
AVFrame* filteredAVFramePtr = filteredAVFrame.release();
191-
auto deleter = [filteredAVFramePtr](void*) {
192-
UniqueAVFrameavFrameToDelete(filteredAVFramePtr);
193-
};
194-
returntorch::from_blob(
195-
filteredAVFramePtr->data[0], shape, strides, deleter, {torch::kUInt8});
196-
}
197-
198209
voidCpuDeviceInterface::createSwsContext(
199210
const FiltersContext& filtersContext,
200211
constenum AVColorSpace colorspace) {

‎src/torchcodec/_core/CpuDeviceInterface.h‎

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,11 @@ class CpuDeviceInterface : public DeviceInterface {
2626
voidinitializeContext(
2727
[[maybe_unused]] AVCodecContext* codecContext)override {}
2828

29+
std::unique_ptr<FiltersContext>initializeFiltersContext(
30+
const VideoStreamOptions& videoStreamOptions,
31+
const UniqueAVFrame& avFrame,
32+
const AVRational& timeBase)override;
33+
2934
voidconvertAVFrameToFrameOutput(
3035
const VideoStreamOptions& videoStreamOptions,
3136
const AVRational& timeBase,
@@ -39,21 +44,21 @@ class CpuDeviceInterface : public DeviceInterface {
3944
const UniqueAVFrame& avFrame,
4045
torch::Tensor& outputTensor);
4146

42-
torch::TensorconvertAVFrameToTensorUsingFilterGraph(
43-
const UniqueAVFrame& avFrame);
47+
std::unique_ptr<FiltersContext>initializeFiltersContextInternal(
48+
const VideoStreamOptions& videoStreamOptions,
49+
const UniqueAVFrame& avFrame,
50+
const AVRational& timeBase);
4451

4552
voidcreateSwsContext(
4653
const FiltersContext& filtersContext,
4754
constenum AVColorSpace colorspace);
4855

49-
// color-conversion fields. Only one of FilterGraphContext and
50-
// UniqueSwsContext should be non-null.
51-
std::unique_ptr<FilterGraph> filterGraphContext_;
56+
// SWS color conversion context
5257
UniqueSwsContext swsContext_;
5358

54-
// Used to know whether a newFilterGraphContext orUniqueSwsContext should
59+
// Used to know whether a new UniqueSwsContext should
5560
// be created before decoding a new frame.
56-
FiltersContext prevFiltersContext_;
61+
std::unique_ptr<FiltersContext> prevFiltersContext_;
5762
};
5863

5964
}// namespace facebook::torchcodec

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp