class ngraph::pass::low_precision::NetworkHelper
The NetworkHelper class encapsulates manipulations of an nGraph function.
#include <network_helper.hpp>
// Utility class: apart from one nested-class forward declaration, every
// member declared here is a public static function, so no instance is needed.
//
// NOTE(review): only declarations are visible in this listing. Remarks below
// that describe behaviour are inferred from names and signatures and must be
// confirmed against the implementation before being relied upon.
class NetworkHelper
{
public:
    // ---- nested types ------------------------------------------------------

    // Forward declaration; used as the return type of
    // moveDequantizationAfter() and moveDequantizationBefore().
    class InsertDequantizationResult;

    // ---- type and graph-topology queries -----------------------------------

    // Presumably tests `type` against each entry of `types` -- confirm.
    static bool is_castable_to_one_of(NodeTypeInfo type, const std::unordered_set<NodeTypeInfo>& types);

    // Two views of what consumes `node`: as input ports and as nodes.
    static std::vector<Input<Node>> consumer_inputs(std::shared_ptr<Node> node);
    static std::vector<std::shared_ptr<Node>> consumers(std::shared_ptr<Node> node);

    static bool isConstantPath(const std::shared_ptr<Node>& op);

    // ---- output precision --------------------------------------------------

    // Both templates take an operation of any type plus a target element type
    // and return a node handle.
    template <typename OperationType>
    static std::shared_ptr<Node> setOutDataPrecisionForTypeRelaxed(
        std::shared_ptr<OperationType> operation,
        const element::Type& precision);

    template <typename OperationType>
    static std::shared_ptr<Node> setOutDataPrecision(
        std::shared_ptr<OperationType> operation,
        const element::Type& precision);

    // ---- constants, channels and parents -----------------------------------

    // `outIdx` defaults to 0.
    static std::shared_ptr<opset1::Constant> foldDequantizationConstant(
        const std::shared_ptr<opset1::Constant>& foldingConstant,
        const std::shared_ptr<Node>& operation,
        const size_t outIdx = 0);

    // `isOnWeights` defaults to false.
    static size_t getOutputChannelsCount(std::shared_ptr<const Node> layer, bool isOnWeights = false);

    // Defaults: an empty exception set and portIndex == -1.
    // NOTE(review): -1 looks like an "all ports" sentinel -- confirm.
    static std::vector<std::shared_ptr<Node>> getParentsRecursivelyExceptTypes(
        std::shared_ptr<Node> layer,
        const std::unordered_set<NodeTypeInfo>& exceptionLayerTypes = {},
        const int portIndex = -1);

    static size_t getInputChannelsCount(std::shared_ptr<Node> layer);
    static size_t getGroupsCount(std::shared_ptr<Node> layer);

    // ---- graph editing -----------------------------------------------------

    static void removeLayer(std::shared_ptr<Node> node);

    static std::shared_ptr<Node> swapMultiplyAndAdd(
        std::shared_ptr<opset1::Add> addAfterMultiply,
        const int multiplyBranch);

    // Three overloads: many-to-many, many-to-one and one-to-one.
    // `overrideName` defaults to true in each of them.
    static void copyInfo(
        const std::vector<std::shared_ptr<Node>>& sources,
        const std::vector<std::shared_ptr<Node>>& targets,
        bool overrideName = true);

    static void copyInfo(
        const std::vector<std::shared_ptr<Node>>& sources,
        const std::shared_ptr<Node>& target,
        bool overrideName = true);

    static void copyInfo(
        const std::shared_ptr<Node>& source,
        const std::shared_ptr<Node>& target,
        bool overrideName = true);

    // ---- constant inspection -----------------------------------------------

    static bool isScalarLike(std::shared_ptr<opset1::Constant> constant);
    static bool isZero(std::shared_ptr<opset1::Constant> constant);
    static std::shared_ptr<opset1::Constant> toScalar(std::shared_ptr<opset1::Constant> constant);

    // `convertIsExpected` defaults to false.
    static std::shared_ptr<Node> getConstantInput(
        const std::shared_ptr<const Node>& node,
        const bool convertIsExpected = false);

    static std::vector<size_t> updateReshapeValues(
        const Shape& elementwiseConstantShape,
        const Shape& elementwiseShape,
        const std::vector<size_t>& reshapeValues);

    static std::shared_ptr<ngraph::opset1::Multiply> optimizeMultipliesAfter(std::shared_ptr<Node> multiply);

    static std::shared_ptr<opset1::Constant> round(std::shared_ptr<Node> node, element::Type target_type);

    // ---- FakeQuantize composition / decomposition --------------------------

    // `defaultPrecisions` defaults to precision_set::int8_support.
    static std::shared_ptr<opset1::FakeQuantize> composeFakeQuantize(
        const std::shared_ptr<opset1::FakeQuantize>& fq,
        const std::vector<ngraph::element::Type>& defaultPrecisions = precision_set::int8_support);

    // Returns a pair of node handles.
    // Defaults: deqPrecision == element::f32, outChannelsShapeIndex == 0.
    static std::tuple<std::shared_ptr<Node>, std::shared_ptr<Node>> decomposeFakeQuantize(
        std::shared_ptr<opset1::FakeQuantize> fq,
        const element::Type precision,
        const float min,
        const float max,
        const bool hasZeroPoint,
        const bool updatePrecision,
        const element::Type deqPrecision = element::f32,
        const size_t outChannelsShapeIndex = 0);

    // `replace` defaults to true.
    static std::shared_ptr<opset1::FakeQuantize> updateFakeQuantize(
        std::shared_ptr<opset1::FakeQuantize> fq,
        element::Type precision,
        float min,
        float max,
        const bool replace = true);

    // Defaults: deqPrecision == element::f32, input == nullptr.
    static FakeQuantizeDequantization makeDequantization(
        const float dequantizationMul,
        const float dequantizationSub,
        const ngraph::element::Type originalPrecision,
        const ngraph::PartialShape& dataNodeOutputShape,
        element::Type precision,
        const element::Type deqPrecision = element::f32,
        std::shared_ptr<ngraph::Node> input = nullptr);

    // `deqPrecision` defaults to element::f32.
    static FakeQuantizeDequantization createDequantizationFromFakeQuantize(
        std::shared_ptr<opset1::FakeQuantize> fq,
        element::Type precision,
        float min,
        float max,
        const bool hasZeroPoint,
        const bool updatePrecision,
        const element::Type deqPrecision = element::f32);

    // ---- quantization support checks ---------------------------------------

    static bool areQuantizeAndDequantizeSupportedForSubtract(
        const std::shared_ptr<const ngraph::Node>& node,
        const std::vector<ngraph::element::Type>& defaultPrecisions = precision_set::int8_support);

    // NOTE(review): the parameter is spelled `_defaultPrecisions` here, whereas
    // the Subtract counterpart above uses `defaultPrecisions`.
    static bool areQuantizeAndDequantizeSupportedForMultiply(
        const std::shared_ptr<const ngraph::Node>& node,
        const std::vector<ngraph::element::Type>& _defaultPrecisions = precision_set::int8_support);

    static bool isQuantizeSupported(const std::shared_ptr<opset1::FakeQuantize>& fakeQuantize);

    // ---- dequantization lookup and normalization ---------------------------

    // Defaults: int8_support precisions, parentIndex == 0ul, inPlace == false.
    // NOTE(review): `_defaultPrecisions` is taken by value here, while sibling
    // declarations take the same vector type by const reference.
    static FakeQuantizeDequantization getDequantization(
        const std::shared_ptr<const Node>& node,
        const std::vector<ngraph::element::Type> _defaultPrecisions = precision_set::int8_support,
        const size_t parentIndex = 0ul,
        const bool inPlace = false);

    // `convertIsMandatory` defaults to false.
    static FakeQuantizeDequantization getDequantizationBelow(
        const std::shared_ptr<Node>& node,
        const bool convertIsMandatory = false);

    static FakeQuantizeDequantization normalizeDequantization(FakeQuantizeDequantization dequantization);

    // `convertIsExpected` defaults to false.
    static std::shared_ptr<opset1::Constant> normalizeDequantizationShape(
        const std::shared_ptr<Node>& eltwise,
        const bool convertIsExpected = false);

    // NOTE(review): the parameter is named `add` although its type is
    // opset1::Subtract.
    static std::shared_ptr<Node> optimizeSubtract(std::shared_ptr<opset1::Subtract> add);

    // ---- moving dequantization across an operation -------------------------

    // `defaultPrecisions` defaults to precision_set::int8_support.
    static InsertDequantizationResult moveDequantizationAfter(
        const std::shared_ptr<ngraph::Node>& operation,
        const FakeQuantizeDequantization& dequantization,
        const bool updatePrecision,
        const bool moveSubtract,
        const std::vector<ngraph::element::Type>& defaultPrecisions = precision_set::int8_support);

    // Same leading parameters as moveDequantizationAfter(), but without a
    // precision-list argument.
    static InsertDequantizationResult moveDequantizationBefore(
        const std::shared_ptr<ngraph::Node>& operation,
        const FakeQuantizeDequantization& dequantization,
        const bool updatePrecision,
        const bool moveSubtract);

    // Returns a vector of vectors of constants.
    // NOTE(review): this is the only declaration spelled with `ov::Node`, and
    // both arguments are taken by const value rather than by const reference.
    static std::vector<std::vector<std::shared_ptr<ngraph::opset1::Constant>>> splitConstantsBeforeConcat(
        const std::shared_ptr<ov::Node> concat,
        const std::vector<std::shared_ptr<opset1::Constant>> currConstants);

    // ---- miscellaneous checks and index lookups ----------------------------

    static bool checkConstantValuePrecision(
        const element::Type expectedPrecision,
        const std::shared_ptr<Node>& constant);

    // Index lookups for a (parent, child) pair of nodes.
    static size_t getChildInputIndex(
        const std::shared_ptr<ngraph::Node>& parent,
        const std::shared_ptr<ngraph::Node>& child);

    static size_t getParentOutputIndex(
        const std::shared_ptr<ngraph::Node>& parent,
        const std::shared_ptr<ngraph::Node>& child);

    static FakeQuantizeDequantizationValues createEmptyValues(
        const FakeQuantizeDequantization& dequantization,
        const element::Type precision);

    static bool isZeroConst(const std::shared_ptr<Node>& node);

    // `dataPrecision` defaults to a default-constructed DataPrecision.
    static bool checkZeroPoint(
        const std::shared_ptr<Node>& node,
        const DataPrecision& dataPrecision = DataPrecision());

    static std::shared_ptr<Node> toScalarIfPossible(std::shared_ptr<Node> node);

    // ---- folding -----------------------------------------------------------

    // Two overloads; the second adds `roundValues` and an
    // `outChannelsShapeIndex` that defaults to 0.
    // NOTE(review): `outChannelsShapeIndex` is `int` here but `size_t` in
    // decomposeFakeQuantize().
    static std::shared_ptr<Node> fold_fake_quantize(const std::shared_ptr<opset1::FakeQuantize>& fq);

    static std::shared_ptr<Node> fold_fake_quantize(
        const std::shared_ptr<opset1::FakeQuantize>& fq,
        const bool roundValues,
        int outChannelsShapeIndex = 0);

    // Defaults: int8_support precisions, inPlace == false.
    static FakeQuantizeDequantization foldDequantization(
        const std::shared_ptr<Node>& node,
        const size_t branchIndex,
        const std::vector<ngraph::element::Type>& defaultPrecisions = precision_set::int8_support,
        const bool inPlace = false);

    // `defaultPrecisions` defaults to precision_set::int8_support.
    static std::shared_ptr<ngraph::Node> separateInStandaloneBranch(
        std::shared_ptr<ngraph::Node> node,
        const std::vector<ngraph::element::Type>& defaultPrecisions = precision_set::int8_support);

    static std::shared_ptr<opset1::FakeQuantize> fuseConvert(const std::shared_ptr<opset1::FakeQuantize>& fakeQuantize);

    // ---- precision utilities -----------------------------------------------

    // Presumably returns the element types present in both `v1` and `v2` --
    // confirm ordering and duplicate handling in the implementation.
    static std::vector<element::Type> precisionIntersection(
        const std::vector<element::Type>& v1,
        const std::vector<element::Type>& v2);

    static bool isPrecisionPreserved(const std::shared_ptr<ngraph::Node>& node);

    static void insertDequantizationAfter(
        const std::shared_ptr<Node>& originalNode,
        const std::shared_ptr<Node>& dequantization,
        const std::shared_ptr<Node>& newNode);

    // ---- shared attributes -------------------------------------------------

    // Takes a shared value plus a list of weak references to attributes; both
    // types are derived from SharedAttribute::SharedValueAttribute.
    template <typename SharedAttribute>
    static void reassign(
        const std::shared_ptr<typename SharedAttribute::SharedValueAttribute::SharedValue>& sharedValue,
        const std::vector<std::weak_ptr<typename SharedAttribute::SharedValueAttribute>>& attributes);

    // ---- level calculation -------------------------------------------------

    // The first six arguments are by-value inputs. The last four are non-const
    // float references, so the callee is able to write to them.
    // NOTE(review): presumably they are pure outputs -- confirm.
    static size_t calculateLevels(
        const float dataPrecisionMin,
        const float dataPrecisionMax,
        const float combinedIntervalLow,
        const float combinedIntervalHigh,
        const float minIntervalLow,
        const float minIntervalHigh,
        float& dequantizationMul,
        float& dequantizationSub,
        float& updatedOutputLowValue,
        float& updatedOutputHighValue);
};