Rate this Page
★★★★★
Program Listing for File custom_class.h
↰ Return to documentation for file (`torch/custom_class.h`)
#pragma once#include<ATen/core/builtin_function.h>#include<ATen/core/function_schema.h>#include<ATen/core/ivalue.h>#include<ATen/core/class_type.h>#include<ATen/core/op_registration/infer_schema.h>#include<ATen/core/stack.h>#include<c10/util/C++17.h>#include<c10/util/Metaprogramming.h>#include<c10/util/TypeList.h>#include<c10/util/TypeTraits.h>#include<torch/custom_class_detail.h>#include<torch/library.h>#include<functional>#include<sstream>namespacetorch{template<class...Types>detail::types<void,Types...>init(){returndetail::types<void,Types...>{};}template<typenameFunc,typename...ParameterTypeList>structInitLambda{Funcf;};template<typenameFunc>decltype(auto)init(Func&&f){usingInitTraits=c10::guts::infer_function_traits_t<std::decay_t<Func>>;usingParameterTypeList=typenameInitTraits::parameter_types;InitLambda<Func,ParameterTypeList>init{std::forward<Func>(f)};returninit;}template<classCurClass>classclass_:public::torch::detail::class_base{static_assert(std::is_base_of_v<CustomClassHolder,CurClass>,"torch::class_<T> requires T to inherit from CustomClassHolder");public:explicitclass_(conststd::string&namespaceName,conststd::string&className,std::stringdoc_string=""):class_base(namespaceName,className,std::move(doc_string),typeid(c10::intrusive_ptr<CurClass>),typeid(c10::tagged_capsule<CurClass>)){}template<typename...Types>class_&def(torch::detail::types<void,Types...>/*unused*/,std::stringdoc_string="",std::initializer_list<arg>default_args={}){// Used in combination with// torch::init<...>()autofunc=[](c10::tagged_capsule<CurClass>self,Types...args){autoclassObj=c10::make_intrusive<CurClass>(args...);autoobject=self.ivalue.toObject();object->setSlot(0,c10::IValue::make_capsule(std::move(classObj)));};defineMethod("__init__",std::move(func),std::move(doc_string),default_args);return*this;}// Used in combination with 
torch::init([]lambda(){......})template<typenameFunc,typename...ParameterTypes>class_&def(InitLambda<Func,c10::guts::typelist::typelist<ParameterTypes...>>init,std::stringdoc_string="",std::initializer_list<arg>default_args={}){autoinit_lambda_wrapper=[func=std::move(init.f)](c10::tagged_capsule<CurClass>self,ParameterTypes...arg){c10::intrusive_ptr<CurClass>classObj=std::invoke(func,std::forward<ParameterTypes>(arg)...);autoobject=self.ivalue.toObject();object->setSlot(0,c10::IValue::make_capsule(classObj));};defineMethod("__init__",std::move(init_lambda_wrapper),std::move(doc_string),default_args);return*this;}template<typenameFunc>class_&def(std::stringname,Funcf,std::stringdoc_string="",std::initializer_list<arg>default_args={}){autowrapped_f=detail::wrap_func<CurClass,Func>(std::move(f));defineMethod(std::move(name),std::move(wrapped_f),std::move(doc_string),default_args);return*this;}template<typenameFunc>class_&def_static(std::stringname,Funcfunc,std::stringdoc_string=""){autoqualMethodName=qualClassName+"."+name;autoschema=c10::inferFunctionSchemaSingleReturn<Func>(std::move(name),"");autowrapped_func=[func=std::move(func)](jit::Stack&stack)mutable->void{usingRetType=typenamec10::guts::infer_function_traits_t<Func>::return_type;detail::BoxedProxy<RetType,Func>()(stack,func);};automethod=std::make_unique<jit::BuiltinOpFunction>(std::move(qualMethodName),std::move(schema),std::move(wrapped_func),std::move(doc_string));classTypePtr->addStaticMethod(method.get());registerCustomClassMethod(std::move(method));return*this;}template<typenameGetterFunc,typenameSetterFunc>class_&def_property(conststd::string&name,GetterFuncgetter_func,SetterFuncsetter_func,std::stringdoc_string=""){torch::jit::Function*getter{};torch::jit::Function*setter{};autowrapped_getter=detail::wrap_func<CurClass,GetterFunc>(std::move(getter_func));getter=defineMethod(name+"_getter",wrapped_getter,doc_string);autowrapped_setter=detail::wrap_func<CurClass,SetterFunc>(std::move(setter_func));sette
r=defineMethod(name+"_setter",wrapped_setter,doc_string);classTypePtr->addProperty(name,getter,setter);return*this;}template<typenameGetterFunc>class_&def_property(conststd::string&name,GetterFuncgetter_func,std::stringdoc_string=""){torch::jit::Function*getter{};autowrapped_getter=detail::wrap_func<CurClass,GetterFunc>(std::move(getter_func));getter=defineMethod(name+"_getter",wrapped_getter,doc_string);classTypePtr->addProperty(name,getter,nullptr);return*this;}template<typenameT>class_&def_readwrite(conststd::string&name,TCurClass::*field){autogetter_func=[field=field](constc10::intrusive_ptr<CurClass>&self){returnself.get()->*field;};autosetter_func=[field=field](constc10::intrusive_ptr<CurClass>&self,Tvalue){self.get()->*field=value;};returndef_property(name,getter_func,setter_func);}template<typenameT>class_&def_readonly(conststd::string&name,TCurClass::*field){autogetter_func=[field=std::move(field)](constc10::intrusive_ptr<CurClass>&self){returnself.get()->*field;};returndef_property(name,getter_func);}class_&_def_unboxed(conststd::string&name,std::function<void(jit::Stack&)>func,c10::FunctionSchemaschema,std::stringdoc_string=""){automethod=std::make_unique<jit::BuiltinOpFunction>(qualClassName+"."+name,std::move(schema),std::move(func),std::move(doc_string));classTypePtr->addMethod(method.get());registerCustomClassMethod(std::move(method));return*this;}template<typenameGetStateFn,typenameSetStateFn>class_&def_pickle(GetStateFn&&get_state,SetStateFn&&set_state){static_assert(c10::guts::is_stateless_lambda<std::decay_t<GetStateFn>>::value&&c10::guts::is_stateless_lambda<std::decay_t<SetStateFn>>::value,"def_pickle() currently only supports lambdas as ""__getstate__ and __setstate__ arguments.");def("__getstate__",std::forward<GetStateFn>(get_state));// __setstate__ needs to be registered with some custom handling:// We need to wrap the invocation of the user-provided function// such that we take the return value (i.e. 
c10::intrusive_ptr<CurrClass>)// and assign it to the `capsule` attribute.usingSetStateTraits=c10::guts::infer_function_traits_t<std::decay_t<SetStateFn>>;usingSetStateArg=typenamec10::guts::typelist::head_t<typenameSetStateTraits::parameter_types>;autosetstate_wrapper=[set_state=std::forward<SetStateFn>(set_state)](c10::tagged_capsule<CurClass>self,SetStateArgarg){c10::intrusive_ptr<CurClass>classObj=std::invoke(set_state,std::move(arg));autoobject=self.ivalue.toObject();object->setSlot(0,c10::IValue::make_capsule(classObj));};defineMethod("__setstate__",detail::wrap_func<CurClass,decltype(setstate_wrapper)>(std::move(setstate_wrapper)));// type validationautogetstate_schema=classTypePtr->getMethod("__getstate__").getSchema();#ifndef STRIP_ERROR_MESSAGESautoformat_getstate_schema=[&getstate_schema](){std::stringstreamss;ss<<getstate_schema;returnss.str();};#endifTORCH_CHECK(getstate_schema.arguments().size()==1,"__getstate__ should take exactly one argument: self. Got: ",format_getstate_schema());autofirst_arg_type=getstate_schema.arguments().at(0).type();TORCH_CHECK(*first_arg_type==*classTypePtr,"self argument of __getstate__ must be the custom class type. Got ",first_arg_type->repr_str());TORCH_CHECK(getstate_schema.returns().size()==1,"__getstate__ should return exactly one value for serialization. Got: ",format_getstate_schema());autoser_type=getstate_schema.returns().at(0).type();autosetstate_schema=classTypePtr->getMethod("__setstate__").getSchema();autoarg_type=setstate_schema.arguments().at(1).type();TORCH_CHECK(ser_type->isSubtypeOf(*arg_type),"__getstate__'s return type should be a subtype of ""input argument of __setstate__. 
Got ",ser_type->repr_str()," but expected ",arg_type->repr_str());return*this;}private:template<typenameFunc>torch::jit::Function*defineMethod(std::stringname,Funcfunc,std::stringdoc_string="",std::initializer_list<arg>default_args={}){autoqualMethodName=qualClassName+"."+name;autoschema=c10::inferFunctionSchemaSingleReturn<Func>(std::move(name),"");// If default values are provided for function arguments, there must be// none (no default values) or default values for all function// arguments, except for self. This is because argument names are not// extracted by inferFunctionSchemaSingleReturn, and so there must be a// torch::arg instance in default_args even for arguments that do not// have an actual default value provided.TORCH_CHECK(default_args.size()==0||default_args.size()==schema.arguments().size()-1,"Default values must be specified for none or all arguments");// If there are default args, copy the argument names and default values to// the function schema.if(default_args.size()>0){schema=withNewArguments(schema,default_args);}autowrapped_func=[func=std::move(func)](jit::Stack&stack)mutable->void{// TODO: we need to figure out how to profile calls to custom functions// like this! 
Currently can't do it because the profiler stuff is in// libtorch and not ATenusingRetType=typenamec10::guts::infer_function_traits_t<Func>::return_type;detail::BoxedProxy<RetType,Func>()(stack,func);};automethod=std::make_unique<jit::BuiltinOpFunction>(qualMethodName,std::move(schema),std::move(wrapped_func),std::move(doc_string));// Register the method here to keep the Method alive.// ClassTypes do not hold ownership of their methods (normally it// those are held by the CompilationUnit), so we need a proxy for// that behavior here.automethod_val=method.get();classTypePtr->addMethod(method_val);registerCustomClassMethod(std::move(method));returnmethod_val;}};template<typenameCurClass,typename...CtorArgs>c10::IValuemake_custom_class(CtorArgs&&...args){autouserClassInstance=c10::make_intrusive<CurClass>(std::forward<CtorArgs>(args)...);returnc10::IValue(std::move(userClassInstance));}// Alternative api for creating a torchbind class over torch::class_ this api is// preferred to prevent size regressions on Edge usecases. 
Must be used in// conjunction with TORCH_SELECTIVE_CLASS macro aka// selective_class<foo>("foo_namespace", TORCH_SELECTIVE_CLASS("foo"))template<classCurClass>inlineclass_<CurClass>selective_class_(conststd::string&namespace_name,detail::SelectiveStr<true>className){autoclass_name=std::string(className.operatorconstchar*());returntorch::class_<CurClass>(namespace_name,class_name);}template<classCurClass>inlinedetail::ClassNotSelectedselective_class_(conststd::string&/*unused*/,detail::SelectiveStr<false>/*unused*/){returndetail::ClassNotSelected();}// jit namespace for backward-compatibility// We previously defined everything in torch::jit but moved it out to// better reflect that these features are not limited only to TorchScriptnamespacejit{using::torch::class_;using::torch::getCustomClass;using::torch::init;using::torch::isCustomClass;}// namespace jittemplate<classCurClass>inlineclass_<CurClass>Library::class_(conststd::string&className){TORCH_CHECK(kind_==DEF||kind_==FRAGMENT,"class_(\"",className,"\"): Cannot define a class inside of a TORCH_LIBRARY_IMPL block. ""All class_()s should be placed in the (unique) TORCH_LIBRARY block for their namespace. ""(Error occurred at ",file_,":",line_,")");TORCH_INTERNAL_ASSERT(ns_.has_value(),file_,":",line_);returntorch::class_<CurClass>(*ns_,className);}conststd::unordered_set<std::string>getAllCustomClassesNames();template<classCurClass>inlineclass_<CurClass>Library::class_(detail::SelectiveStr<true>className){autoclass_name=std::string(className.operatorconstchar*());TORCH_CHECK(kind_==DEF||kind_==FRAGMENT,"class_(\"",class_name,"\"): Cannot define a class inside of a TORCH_LIBRARY_IMPL block. ""All class_()s should be placed in the (unique) TORCH_LIBRARY block for their namespace. 
""(Error occurred at ",file_,":",line_,")");TORCH_INTERNAL_ASSERT(ns_.has_value(),file_,":",line_);returntorch::class_<CurClass>(*ns_,class_name);}template<classCurClass>inlinedetail::ClassNotSelectedLibrary::class_(detail::SelectiveStr<false>/*unused*/){returndetail::ClassNotSelected();}}// namespace torch