class ov::pass::PassConfig
Overview
Class representing a transformations config that is used to disable/enable transformations registered inside pass::Manager; it also allows setting a callback for all transformations or for a particular transformation.
#include <pass_config.hpp>
class PassConfig
{
public:
    // methods

    void disable(const DiscreteTypeInfo& type_info);

    template <
        class T,
        typename std::enable_if<!ngraph::HasTypeInfoMember<T>::value, bool>::type = true
        >
    void disable();

    void enable(const DiscreteTypeInfo& type_info);

    template <
        class T,
        typename std::enable_if<!ngraph::HasTypeInfoMember<T>::value, bool>::type = true
        >
    void enable();

    void set_callback(const param_callback& callback);

    template <typename... Args>
    typename std::enable_if<sizeof...(Args) == 0>::type set_callback(const param_callback& callback);

    template <
        typename T,
        class... Args,
        typename std::enable_if<!ngraph::HasTypeInfoMember<T>::value, bool>::type = true
        >
    void set_callback(const param_callback& callback);

    param_callback get_callback(const DiscreteTypeInfo& type_info) const;

    template <
        class T,
        typename std::enable_if<ngraph::HasTypeInfoMember<T>::value, bool>::type = true
        >
    param_callback get_callback() const;

    bool is_disabled(const DiscreteTypeInfo& type_info) const;

    template <
        class T,
        typename std::enable_if<ngraph::HasTypeInfoMember<T>::value, bool>::type = true
        >
    bool is_disabled() const;

    bool is_enabled(const DiscreteTypeInfo& type_info) const;

    template <
        class T,
        typename std::enable_if<ngraph::HasTypeInfoMember<T>::value, bool>::type = true
        >
    bool is_enabled() const;

    void add_disabled_passes(const PassConfig& rhs);
};
Detailed Documentation
Class representing a transformations config that is used to disable/enable transformations registered inside pass::Manager; it also allows setting a callback for all transformations or for a particular transformation.
When a pass::Manager is created, all passes registered inside this manager, including nested passes, share the same PassConfig instance. To work with this class, first get the shared instance by calling the manager.get_pass_config() method. You can then disable or enable passes based on their type_info. For example:
pass::Manager manager;
manager.register_pass<CommonOptimizations>();
auto pass_config = manager.get_pass_config();
// Disable the ConvertGELU pass nested inside the CommonOptimizations pipeline
pass_config->disable<ConvertGELU>();
// f is the function (model) to transform
manager.run_passes(f);
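A minimal, self-contained C++ model of the shared-instance behavior is sketched below. It is not OpenVINO code: Config, Pass, Manager, GeluLikePass and FuseLikePass are hypothetical stand-ins, and std::type_index replaces DiscreteTypeInfo. Every pass registered with the manager holds a pointer to the same Config, so disabling a type through the handle returned by get_pass_config() is seen by every pass.

```cpp
#include <cassert>
#include <memory>
#include <set>
#include <string>
#include <typeindex>
#include <typeinfo>
#include <vector>

// Hypothetical stand-in for PassConfig; std::type_index replaces DiscreteTypeInfo.
struct Config {
    std::set<std::type_index> disabled;
    template <class T> void disable() { disabled.insert(std::type_index(typeid(T))); }
    bool is_disabled(const std::type_index& t) const { return disabled.count(t) > 0; }
};

// Hypothetical pass base: each pass keeps a pointer to the shared Config.
struct Pass {
    std::shared_ptr<Config> config;
    virtual ~Pass() = default;
    virtual std::type_index type() const = 0;
    virtual void run(std::vector<std::string>& log) = 0;
};

struct GeluLikePass : Pass {
    std::type_index type() const override { return std::type_index(typeid(GeluLikePass)); }
    void run(std::vector<std::string>& log) override { log.push_back("GeluLikePass"); }
};

struct FuseLikePass : Pass {
    std::type_index type() const override { return std::type_index(typeid(FuseLikePass)); }
    void run(std::vector<std::string>& log) override { log.push_back("FuseLikePass"); }
};

// Hypothetical manager: all registered passes receive the same Config instance.
class Manager {
public:
    template <class T> void register_pass() {
        auto pass = std::make_shared<T>();
        pass->config = config_;  // share the pointer, do not copy the Config
        passes_.push_back(pass);
    }
    std::shared_ptr<Config> get_pass_config() const { return config_; }
    // Runs every pass whose type is not disabled in its (shared) Config.
    std::vector<std::string> run_passes() {
        std::vector<std::string> log;
        for (const auto& pass : passes_)
            if (!pass->config->is_disabled(pass->type())) pass->run(log);
        return log;
    }
private:
    std::shared_ptr<Config> config_ = std::make_shared<Config>();
    std::vector<std::shared_ptr<Pass>> passes_;
};
```

In this model, disabling GeluLikePass through m.get_pass_config() before run_passes() leaves only FuseLikePass in the returned log, because each pass reads the same Config object the manager handed out.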
Sometimes a transformation has to be called manually from inside another transformation. In that case, before running the nested transformation, you must check that it is not disabled and then pass the current PassConfig instance to it. For example:
// Inside a MatcherPass callback or inside a FunctionPass run_on_function() method,
// call get_pass_config() to get the shared PassConfig instance
auto pass_config = get_pass_config();
// Before running the nested transformation, check whether it is disabled
if (!pass_config->is_disabled<ConvertGELU>()) {
    ConvertGELU pass;
    // Share the current PassConfig with the nested transformation
    pass.set_pass_config(pass_config);
    pass.apply(node);
}
Following this pattern inside your transformations guarantees that nested transformations respect the shared PassConfig settings and are executed correctly.
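The nested-call pattern can be modeled without OpenVINO as follows. This is a simplified sketch: Config, PassBase, InnerPass and OuterPass are hypothetical, and a plain int stands in for the node being transformed. OuterPass checks the shared config before calling InnerPass and hands it the same config instance, mirroring the get_pass_config()/set_pass_config() calls above.

```cpp
#include <cassert>
#include <memory>
#include <set>
#include <typeindex>
#include <typeinfo>

// Hypothetical stand-in for PassConfig.
struct Config {
    std::set<std::type_index> disabled;
    template <class T> void disable() { disabled.insert(std::type_index(typeid(T))); }
    template <class T> bool is_disabled() const { return disabled.count(std::type_index(typeid(T))) > 0; }
};

// Hypothetical base class mirroring get_pass_config()/set_pass_config().
class PassBase {
public:
    void set_pass_config(const std::shared_ptr<Config>& config) { config_ = config; }
    std::shared_ptr<Config> get_pass_config() const { return config_; }
private:
    std::shared_ptr<Config> config_ = std::make_shared<Config>();
};

// Hypothetical nested transformation; returns true if it changed the "node".
class InnerPass : public PassBase {
public:
    bool apply(int& node) { node += 1; return true; }
};

// Hypothetical outer transformation that calls InnerPass manually.
class OuterPass : public PassBase {
public:
    bool apply(int& node) {
        auto pass_config = get_pass_config();
        // Respect the shared config: skip the nested pass if it was disabled.
        if (pass_config->is_disabled<InnerPass>()) return false;
        InnerPass pass;
        pass.set_pass_config(pass_config);  // propagate the shared instance
        return pass.apply(node);
    }
};
```

If the is_disabled() check were omitted, disabling InnerPass in the shared config would have no effect on this manual call, which is why the check has to be written explicitly.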
Methods
void disable(const DiscreteTypeInfo& type_info)
Disable transformation by its type_info.
Parameters:
type_info: transformation type_info
template <
class T,
typename std::enable_if<!ngraph::HasTypeInfoMember<T>::value, bool>::type = true
>
void disable()
Disable transformation by its class type (based on type_info)
void enable(const DiscreteTypeInfo& type_info)
Enable transformation by its type_info.
Parameters:
type_info: transformation type_info
template <
class T,
typename std::enable_if<!ngraph::HasTypeInfoMember<T>::value, bool>::type = true
>
void enable()
Enable transformation by its class type (based on type_info)
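The templated disable<T>() and enable<T>() overloads are gated with std::enable_if on ngraph::HasTypeInfoMember<T>. That trait's definition is not shown on this page; the sketch below assumes it detects whether T declares a static type_info data member, and uses a hypothetical has_type_info_member trait (built with C++17 std::void_t) to show how such a condition selects between two overloads that both forward to a type_info-based function. All names here are illustrative, not OpenVINO's.

```cpp
#include <cassert>
#include <string>
#include <type_traits>

// Hypothetical trait: true when T has a static data member named type_info.
template <class T, class = void>
struct has_type_info_member : std::false_type {};
template <class T>
struct has_type_info_member<T, std::void_t<decltype(T::type_info)>> : std::true_type {};

using TypeInfo = std::string;  // stand-in for DiscreteTypeInfo

struct NewStylePass {  // exposes its type info through a static function
    static TypeInfo get_type_info_static() { return "NewStylePass"; }
};
struct OldStylePass {  // exposes its type info through a static data member
    static inline const TypeInfo type_info = "OldStylePass";
};

TypeInfo last_disabled;  // records what the type_info-based overload received
void disable(const TypeInfo& info) { last_disabled = info; }

// Selected only when T has no type_info member.
template <class T, typename std::enable_if<!has_type_info_member<T>::value, bool>::type = true>
void disable() { disable(T::get_type_info_static()); }

// Selected only when T has a type_info member.
template <class T, typename std::enable_if<has_type_info_member<T>::value, bool>::type = true>
void disable() { disable(T::type_info); }
```

Because exactly one enable_if condition holds for any T, a call such as disable<OldStylePass>() resolves to a single overload; the other is removed by substitution failure rather than causing an ambiguity.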
void set_callback(const param_callback& callback)
Set a callback for all kinds of transformations.
template <
typename T,
class... Args,
typename std::enable_if<!ngraph::HasTypeInfoMember<T>::value, bool>::type = true
>
void set_callback(const param_callback& callback)
Set a callback for one or more particular transformation class types.
The example below shows how to set a callback for one or multiple passes using this method.
pass_config->set_callback<ov::pass::ConvertBatchToSpace,
                          ov::pass::ConvertSpaceToBatch>(
    [](const_node_ptr& node) -> bool {
        // Returning true disables the transformation for this node.
        // Here both transformations are disabled when the input shape rank
        // is 4 and applied for any other (or unknown) rank.
        const auto input_rank = node->get_input_partial_shape(0).rank();
        return input_rank.is_static() && input_rank.get_length() == 4;
    });
Note that inside transformations you must provide code that works with this callback. See the example below:
if (transformation_callback(node)) {
return false; // exit from transformation
}
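The two set_callback templates in the synopsis (one for T plus the remaining Args..., and one enabled only when sizeof...(Args) == 0) suggest a compile-time recursion over the listed types. The sketch below models that recursion together with the "exit if the callback returns true" convention. It is a simplified model, not OpenVINO's implementation: Config, the *Like pass types, try_transform and the default callback that never skips are all assumptions, and a node is reduced to its input rank.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <string>
#include <type_traits>

// Hypothetical stand-ins: TypeInfo is a string, a "node" is reduced to its input rank.
using TypeInfo = std::string;
using Callback = std::function<bool(long /* input rank */)>;

struct BatchToSpaceLike { static TypeInfo type() { return "BatchToSpaceLike"; } };
struct SpaceToBatchLike { static TypeInfo type() { return "SpaceToBatchLike"; } };
struct OtherPassLike    { static TypeInfo type() { return "OtherPassLike"; } };

class Config {
public:
    // Terminates the recursion once no types are left.
    template <typename... Args>
    typename std::enable_if<sizeof...(Args) == 0>::type set_callback(const Callback&) {}

    // Stores the callback for T, then recurses over the remaining types.
    template <typename T, typename... Args>
    void set_callback(const Callback& cb) {
        callbacks_[T::type()] = cb;
        set_callback<Args...>(cb);
    }

    // Returns the callback registered for a type, or one that never skips (assumed default).
    Callback get_callback(const TypeInfo& t) const {
        auto it = callbacks_.find(t);
        if (it != callbacks_.end()) return it->second;
        return [](long) { return false; };
    }

private:
    std::map<TypeInfo, Callback> callbacks_;
};

// Hypothetical transformation entry point: exits early when the callback returns true.
template <class T>
bool try_transform(const Config& cfg, long input_rank) {
    if (cfg.get_callback(T::type())(input_rank)) {
        return false;  // skipped by the callback
    }
    return true;  // the actual rewrite would happen here
}
```

With cfg.set_callback<BatchToSpaceLike, SpaceToBatchLike>([](long rank) { return rank == 4; }), both registered types are skipped for rank-4 inputs and applied otherwise, while OtherPassLike falls back to the default callback and is always applied.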
param_callback get_callback(const DiscreteTypeInfo& type_info) const
Get the callback for the given transformation type_info.
If no callback was set for the given transformation type, the global callback is returned. If no global callback was set either, the default callback is returned.
Parameters:
type_info: transformation type_info
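The lookup order described above (per-type callback, then global callback, then default callback) can be modeled in a few lines. In this hypothetical sketch the global and default callbacks share one slot that starts out as the default until the one-argument set_callback() replaces it; the string-keyed two-argument set_callback() overload is an illustrative assumption, not part of the documented API, and the default "never skip" behavior is likewise assumed.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <string>

using TypeInfo = std::string;               // stand-in for DiscreteTypeInfo
using Callback = std::function<bool(int)>;  // stand-in for param_callback

class Config {
public:
    // Global callback: used for every type that has no callback of its own.
    void set_callback(const Callback& cb) { global_ = cb; }
    // Hypothetical per-type setter keyed by TypeInfo.
    void set_callback(const TypeInfo& t, const Callback& cb) { per_type_[t] = cb; }

    // Lookup order: per-type callback first, otherwise the global one
    // (which is still the default callback if set_callback(cb) was never called).
    Callback get_callback(const TypeInfo& t) const {
        auto it = per_type_.find(t);
        return it != per_type_.end() ? it->second : global_;
    }

private:
    Callback global_ = [](int) { return false; };  // assumed default: never skip
    std::map<TypeInfo, Callback> per_type_;
};
```

Setting a per-type callback for one transformation therefore overrides the global callback only for that type; all other types keep using the global (or default) one.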
template <
class T,
typename std::enable_if<ngraph::HasTypeInfoMember<T>::value, bool>::type = true
>
param_callback get_callback() const
Get callback for given transformation class type.
Returns:
callback lambda function
bool is_disabled(const DiscreteTypeInfo& type_info) const
Check whether the transformation type is disabled.
Parameters:
type_info: transformation type_info
Returns:
true if transformation type was disabled and false otherwise
template <
class T,
typename std::enable_if<ngraph::HasTypeInfoMember<T>::value, bool>::type = true
>
bool is_disabled() const
Check whether the transformation class type is disabled.
Returns:
true if transformation type was disabled and false otherwise
bool is_enabled(const DiscreteTypeInfo& type_info) const
Check whether the transformation type is force enabled.
Parameters:
type_info: transformation type_info
Returns:
true if transformation type was force enabled and false otherwise
template <
class T,
typename std::enable_if<ngraph::HasTypeInfoMember<T>::value, bool>::type = true
>
bool is_enabled() const
Check whether the transformation class type is force enabled.
Returns:
true if transformation type was force enabled and false otherwise
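The difference between is_disabled() and is_enabled() is easiest to see with three states: a type that was never touched is neither disabled nor force enabled. The simplified model below assumes that enable() and disable() move a type between two mutually exclusive sets; the Config class and its string-based TypeInfo are hypothetical stand-ins, not OpenVINO's implementation.

```cpp
#include <cassert>
#include <set>
#include <string>

using TypeInfo = std::string;  // stand-in for DiscreteTypeInfo

// Simplified model: a type is untouched, disabled, or force enabled.
class Config {
public:
    void disable(const TypeInfo& t) { enabled_.erase(t); disabled_.insert(t); }
    void enable(const TypeInfo& t)  { disabled_.erase(t); enabled_.insert(t); }
    bool is_disabled(const TypeInfo& t) const { return disabled_.count(t) > 0; }
    bool is_enabled(const TypeInfo& t) const  { return enabled_.count(t) > 0; }

private:
    std::set<TypeInfo> disabled_;
    std::set<TypeInfo> enabled_;
};
```

Under this assumption, is_enabled() returning false does not mean a pass will be skipped; only is_disabled() signals that. is_enabled() reports that the type was explicitly force enabled.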