diff --git a/deepmd/entrypoints/convert.py b/deepmd/entrypoints/convert.py index 4bf514fe51..25f2271cdb 100644 --- a/deepmd/entrypoints/convert.py +++ b/deepmd/entrypoints/convert.py @@ -1,4 +1,4 @@ -from deepmd.utils.convert import convert_13_to_20 +from deepmd.utils.convert import convert_13_to_20, convert_12_to_20 def convert( *, @@ -7,7 +7,9 @@ def convert( output_model: str, **kwargs, ): - if FROM == '1.3': + if FROM == '1.2': + convert_12_to_20(input_model, output_model) + elif FROM == '1.3': convert_13_to_20(input_model, output_model) else: raise RuntimeError('unsupported model version ' + FROM) diff --git a/deepmd/entrypoints/main.py b/deepmd/entrypoints/main.py index 04dc245271..df2db34923 100644 --- a/deepmd/entrypoints/main.py +++ b/deepmd/entrypoints/main.py @@ -361,7 +361,7 @@ def parse_args(args: Optional[List[str]] = None): ) # * convert models - # supported: 1.3->2.0 + # supported: 1.2->2.0, 1.3->2.0 parser_transform = subparsers.add_parser( 'convert-from', parents=[parser_log], @@ -370,7 +370,7 @@ def parse_args(args: Optional[List[str]] = None): parser_transform.add_argument( 'FROM', type = str, - choices = ['1.3'], + choices = ['1.2', '1.3'], help="The original model compatibility", ) parser_transform.add_argument( diff --git a/deepmd/utils/convert.py b/deepmd/utils/convert.py index 0d9c39df88..6ace6a2132 100644 --- a/deepmd/utils/convert.py +++ b/deepmd/utils/convert.py @@ -11,6 +11,15 @@ def convert_13_to_20(input_model: str, output_model: str): os.remove('frozen_model.pbtxt') print("the converted output model (2.0 support) is saved in %s" % output_model) +def convert_12_to_20(input_model: str, output_model: str): + convert_pb_to_pbtxt(input_model, 'frozen_model.pbtxt') + convert_dp12_to_dp13('frozen_model.pbtxt') + convert_dp13_to_dp20('frozen_model.pbtxt') + convert_pbtxt_to_pb('frozen_model.pbtxt', output_model) + if os.path.isfile('frozen_model.pbtxt'): + os.remove('frozen_model.pbtxt') + print("the converted output model (2.0 support) is saved in %s" % output_model) + def convert_pb_to_pbtxt(pbfile: str, pbtxtfile: str): with gfile.FastGFile(pbfile, 'rb') as f: graph_def = tf.GraphDef() @@ -26,6 +35,28 @@ def convert_pbtxt_to_pb(pbtxtfile: str, pbfile: str): text_format.Merge(file_content, graph_def) tf.train.write_graph(graph_def, './', pbfile, as_text=False) +def convert_dp12_to_dp13(file): + file_data = "" + with open(file, "r", encoding="utf-8") as f: + ii = 0 + lines = f.readlines() + while (ii < len(lines)): + line = lines[ii] + file_data += line + ii+=1 + if 'name' in line and ('DescrptSeA' in line or 'ProdForceSeA' in line or 'ProdVirialSeA' in line): + while not('attr' in lines[ii] and '{' in lines[ii]): + file_data += lines[ii] + ii+=1 + file_data += ' attr {\n' + file_data += ' key: \"T\"\n' + file_data += ' value {\n' + file_data += ' type: DT_DOUBLE\n' + file_data += ' }\n' + file_data += ' }\n' + with open(file, "w", encoding="utf-8") as f: + f.write(file_data) + def convert_dp13_to_dp20(fname: str): with open(fname) as fp: file_content = fp.read() diff --git a/doc/getting-started.md b/doc/getting-started.md index 6a10a49eee..76ed8acd5b 100644 --- a/doc/getting-started.md +++ b/doc/getting-started.md @@ -306,6 +306,8 @@ The model compression method requires that the version of DeePMD-kit used in ori ## Model inference +Note that the model for inference is required to be compatible with the DeePMD-kit package. See [Model compatibility](troubleshooting/model-compatability.md) for details. + ### Python interface One may use the python interface of DeePMD-kit for model inference, an example is given as follows ```python @@ -360,6 +362,8 @@ and then run the program: ## Run MD +Note that the model for MD simulations is required to be compatible with the DeePMD-kit package. See [Model compatibility](troubleshooting/model-compatability.md) for details. + ### Run MD with LAMMPS Include deepmd in the pair_style diff --git a/doc/troubleshooting/model-compatability.md b/doc/troubleshooting/model-compatability.md index 3ae1338cc9..fcc73f6cb3 100644 --- a/doc/troubleshooting/model-compatability.md +++ b/doc/troubleshooting/model-compatability.md @@ -1,5 +1,16 @@ -# Model compatability +# Model compatibility -When the version of DeePMD-kit used to training model is different from the that of DeePMD-kit running MDs, one has the problem of model compatability. +When the version of DeePMD-kit used to training model is different from the that of DeePMD-kit running MDs, one has the problem of model compatibility. DeePMD-kit guarantees that the codes with the same major and minor revisions are compatible. That is to say v0.12.5 is compatible to v0.12.0, but is not compatible to v0.11.0 nor v1.0.0. + +One can execuate `dp convert-from` to convert an old model to a new one. + +| Model version | v0.12 | v1.0 | v1.1 | v1.2 | v1.3 | v2.0 | +|:-:|:-----------:|:----------:|:----------:|:----------:|:----------:|:----------:| +| Compatibility | 😢 | 😢 | 😢 | 😊 | 😊 | 😄 | + +**Legend**: +- 😄: The model is compatible with the DeePMD-kit package. +- 😊: The model is incompatible with the DeePMD-kit package, but one can execuate `dp convert-from` to convert an old model to v2.0. +- 😢: The model is incompatible with the DeePMD-kit package, and there is no way to convert models. diff --git a/source/op/descrpt.cc b/source/op/descrpt.cc index 48fc4f3943..10ba125594 100644 --- a/source/op/descrpt.cc +++ b/source/op/descrpt.cc @@ -7,7 +7,7 @@ typedef double boxtensor_t ; typedef double compute_t; REGISTER_OP("Descrpt") -.Attr("T: {float, double}") +.Attr("T: {float, double} = DT_DOUBLE") .Input("coord: T") .Input("type: int32") .Input("natoms: int32") diff --git a/source/op/descrpt_se_a_ef.cc b/source/op/descrpt_se_a_ef.cc index 7f07cc84b9..3ba41624d9 100644 --- a/source/op/descrpt_se_a_ef.cc +++ b/source/op/descrpt_se_a_ef.cc @@ -8,7 +8,7 @@ typedef double boxtensor_t ; typedef double compute_t; REGISTER_OP("DescrptSeAEf") -.Attr("T: {float, double}") +.Attr("T: {float, double} = DT_DOUBLE") .Input("coord: T") .Input("type: int32") .Input("natoms: int32") diff --git a/source/op/descrpt_se_a_ef_para.cc b/source/op/descrpt_se_a_ef_para.cc index 6e38e24a86..2cb3b3445c 100644 --- a/source/op/descrpt_se_a_ef_para.cc +++ b/source/op/descrpt_se_a_ef_para.cc @@ -7,7 +7,7 @@ typedef double boxtensor_t ; typedef double compute_t; REGISTER_OP("DescrptSeAEfPara") -.Attr("T: {float, double}") +.Attr("T: {float, double} = DT_DOUBLE") .Input("coord: T") .Input("type: int32") .Input("natoms: int32") diff --git a/source/op/descrpt_se_a_ef_vert.cc b/source/op/descrpt_se_a_ef_vert.cc index 9b08f87ce6..615b153bf3 100644 --- a/source/op/descrpt_se_a_ef_vert.cc +++ b/source/op/descrpt_se_a_ef_vert.cc @@ -7,7 +7,7 @@ typedef double boxtensor_t ; typedef double compute_t; REGISTER_OP("DescrptSeAEfVert") -.Attr("T: {float, double}") +.Attr("T: {float, double} = DT_DOUBLE") .Input("coord: T") .Input("type: int32") .Input("natoms: int32") diff --git a/source/op/ewald_recp.cc b/source/op/ewald_recp.cc index ae3aa84bc1..9159dc5931 100644 --- a/source/op/ewald_recp.cc +++ b/source/op/ewald_recp.cc @@ -5,7 +5,7 @@ typedef double boxtensor_t ; typedef double compute_t; REGISTER_OP("EwaldRecp") -.Attr("T: {float, double}") +.Attr("T: {float, double} = DT_DOUBLE") .Input("coord: T") .Input("charge: T") .Input("natoms: int32") diff --git a/source/op/gelu_multi_device.cc b/source/op/gelu_multi_device.cc index 953d89f55a..508f60ccef 100644 --- a/source/op/gelu_multi_device.cc +++ b/source/op/gelu_multi_device.cc @@ -2,18 +2,18 @@ #include "gelu.h" REGISTER_OP("Gelu") - .Attr("T: {float, double}") + .Attr("T: {float, double} = DT_DOUBLE") .Input("x: T") .Output("output: T"); REGISTER_OP("GeluGrad") - .Attr("T: {float, double}") + .Attr("T: {float, double} = DT_DOUBLE") .Input("dy: T") .Input("x: T") .Output("output: T"); REGISTER_OP("GeluGradGrad") - .Attr("T: {float, double}") + .Attr("T: {float, double} = DT_DOUBLE") .Input("dy: T") .Input("dy_: T") .Input("x: T") diff --git a/source/op/map_aparam.cc b/source/op/map_aparam.cc index 3eba13990a..f1c98bdc9c 100644 --- a/source/op/map_aparam.cc +++ b/source/op/map_aparam.cc @@ -2,7 +2,7 @@ #include "map_aparam.h" REGISTER_OP("MapAparam") -.Attr("T: {float, double}") +.Attr("T: {float, double} = DT_DOUBLE") .Input("aparam: T") .Input("nlist: int32") .Input("natoms: int32") diff --git a/source/op/neighbor_stat.cc b/source/op/neighbor_stat.cc index fd9ae776e7..11f991b4b7 100644 --- a/source/op/neighbor_stat.cc +++ b/source/op/neighbor_stat.cc @@ -5,7 +5,7 @@ typedef double boxtensor_t ; typedef double compute_t; REGISTER_OP("NeighborStat") - .Attr("T: {float, double}") + .Attr("T: {float, double} = DT_DOUBLE") .Input("coord: T") .Input("type: int32") .Input("natoms: int32") diff --git a/source/op/pair_tab.cc b/source/op/pair_tab.cc index fb3689b5a8..e09ef460b4 100644 --- a/source/op/pair_tab.cc +++ b/source/op/pair_tab.cc @@ -2,7 +2,7 @@ #include "pair_tab.h" REGISTER_OP("PairTab") -.Attr("T: {float, double}") +.Attr("T: {float, double} = DT_DOUBLE") .Input("table_info: double") .Input("table_data: double") .Input("type: int32") diff --git a/source/op/prod_env_mat_multi_device.cc b/source/op/prod_env_mat_multi_device.cc index 6320f1f501..7c7130cda0 100644 --- a/source/op/prod_env_mat_multi_device.cc +++ b/source/op/prod_env_mat_multi_device.cc @@ -6,7 +6,7 @@ #include "prod_env_mat.h" REGISTER_OP("ProdEnvMatA") - .Attr("T: {float, double}") + .Attr("T: {float, double} = DT_DOUBLE") .Input("coord: T") //atomic coordinates .Input("type: int32") //atomic type .Input("natoms: int32") //local atomic number; each type atomic number; daizheyingxiangqude atomic numbers @@ -27,7 +27,9 @@ REGISTER_OP("ProdEnvMatA") // an alias of ProdEnvMatA -- Compatible with v1.3 REGISTER_OP("DescrptSeA") - .Attr("T: {float, double}") + .Attr("T: {float, double} = DT_DOUBLE") + // give a default value to T, compatible with v1.2 + // See https://www.tensorflow.org/guide/create_op#backwards_compatibility .Input("coord: T") .Input("type: int32") .Input("natoms: int32") @@ -46,7 +48,7 @@ REGISTER_OP("DescrptSeA") .Output("nlist: int32"); REGISTER_OP("ProdEnvMatR") - .Attr("T: {float, double}") + .Attr("T: {float, double} = DT_DOUBLE") .Input("coord: T") .Input("type: int32") .Input("natoms: int32") @@ -64,7 +66,7 @@ REGISTER_OP("ProdEnvMatR") // an alias of ProdEnvMatR -- Compatible with v1.3 REGISTER_OP("DescrptSeR") - .Attr("T: {float, double}") + .Attr("T: {float, double} = DT_DOUBLE") .Input("coord: T") .Input("type: int32") .Input("natoms: int32") diff --git a/source/op/prod_force.cc b/source/op/prod_force.cc index e2c01cc211..307d00a85d 100644 --- a/source/op/prod_force.cc +++ b/source/op/prod_force.cc @@ -1,7 +1,7 @@ #include "custom_op.h" REGISTER_OP("ProdForce") -.Attr("T: {float, double}") +.Attr("T: {float, double} = DT_DOUBLE") .Input("net_deriv: T") .Input("in_deriv: T") .Input("nlist: int32") diff --git a/source/op/prod_force_grad.cc b/source/op/prod_force_grad.cc index fff7afd25b..52c8ed845f 100644 --- a/source/op/prod_force_grad.cc +++ b/source/op/prod_force_grad.cc @@ -1,7 +1,7 @@ #include "custom_op.h" REGISTER_OP("ProdForceGrad") -.Attr("T: {float, double}") +.Attr("T: {float, double} = DT_DOUBLE") .Input("grad: T") .Input("net_deriv: T") .Input("in_deriv: T") diff --git a/source/op/prod_force_grad_multi_device.cc b/source/op/prod_force_grad_multi_device.cc index 2dae7c1a0b..5aff4bbbef 100644 --- a/source/op/prod_force_grad_multi_device.cc +++ b/source/op/prod_force_grad_multi_device.cc @@ -2,7 +2,7 @@ #include "prod_force_grad.h" REGISTER_OP("ProdForceSeAGrad") - .Attr("T: {float, double}") + .Attr("T: {float, double} = DT_DOUBLE") .Input("grad: T") .Input("net_deriv: T") .Input("in_deriv: T") @@ -13,7 +13,7 @@ REGISTER_OP("ProdForceSeAGrad") .Output("grad_net: T"); REGISTER_OP("ProdForceSeRGrad") - .Attr("T: {float, double}") + .Attr("T: {float, double} = DT_DOUBLE") .Input("grad: T") .Input("net_deriv: T") .Input("in_deriv: T") diff --git a/source/op/prod_force_multi_device.cc b/source/op/prod_force_multi_device.cc index 748971751e..63e6945906 100644 --- a/source/op/prod_force_multi_device.cc +++ b/source/op/prod_force_multi_device.cc @@ -2,7 +2,7 @@ #include "prod_force.h" REGISTER_OP("ProdForceSeA") - .Attr("T: {float, double}") + .Attr("T: {float, double} = DT_DOUBLE") .Input("net_deriv: T") .Input("in_deriv: T") .Input("nlist: int32") @@ -12,7 +12,7 @@ REGISTER_OP("ProdForceSeA") .Output("force: T"); REGISTER_OP("ProdForceSeR") - .Attr("T: {float, double}") + .Attr("T: {float, double} = DT_DOUBLE") .Input("net_deriv: T") .Input("in_deriv: T") .Input("nlist: int32") diff --git a/source/op/prod_force_se_a_grad.cc b/source/op/prod_force_se_a_grad.cc index 59878bf7d9..7617c244ed 100644 --- a/source/op/prod_force_se_a_grad.cc +++ b/source/op/prod_force_se_a_grad.cc @@ -2,7 +2,7 @@ #include "prod_force_grad.h" REGISTER_OP("ProdForceSeAGrad") -.Attr("T: {float, double}") +.Attr("T: {float, double} = DT_DOUBLE") .Input("grad: T") .Input("net_deriv: T") .Input("in_deriv: T") diff --git a/source/op/prod_force_se_r_grad.cc b/source/op/prod_force_se_r_grad.cc index be8ebec213..9fff3724ed 100644 --- a/source/op/prod_force_se_r_grad.cc +++ b/source/op/prod_force_se_r_grad.cc @@ -2,7 +2,7 @@ #include "prod_force_grad.h" REGISTER_OP("ProdForceSeRGrad") -.Attr("T: {float, double}") +.Attr("T: {float, double} = DT_DOUBLE") .Input("grad: T") .Input("net_deriv: T") .Input("in_deriv: T") diff --git a/source/op/prod_virial.cc b/source/op/prod_virial.cc index 65dc329b4c..d83ab27225 100644 --- a/source/op/prod_virial.cc +++ b/source/op/prod_virial.cc @@ -1,7 +1,7 @@ #include "custom_op.h" REGISTER_OP("ProdVirial") -.Attr("T: {float, double}") +.Attr("T: {float, double} = DT_DOUBLE") .Input("net_deriv: T") .Input("in_deriv: T") .Input("rij: T") diff --git a/source/op/prod_virial_grad.cc b/source/op/prod_virial_grad.cc index afe9a15382..d07a661cb9 100644 --- a/source/op/prod_virial_grad.cc +++ b/source/op/prod_virial_grad.cc @@ -1,7 +1,7 @@ #include "custom_op.h" REGISTER_OP("ProdVirialGrad") -.Attr("T: {float, double}") +.Attr("T: {float, double} = DT_DOUBLE") .Input("grad: T") .Input("net_deriv: T") .Input("in_deriv: T") diff --git a/source/op/prod_virial_grad_multi_device.cc b/source/op/prod_virial_grad_multi_device.cc index a0dd7ddb99..7a37da9b38 100644 --- a/source/op/prod_virial_grad_multi_device.cc +++ b/source/op/prod_virial_grad_multi_device.cc @@ -2,7 +2,7 @@ #include "prod_virial_grad.h" REGISTER_OP("ProdVirialSeAGrad") - .Attr("T: {float, double}") + .Attr("T: {float, double} = DT_DOUBLE") .Input("grad: T") .Input("net_deriv: T") .Input("in_deriv: T") @@ -14,7 +14,7 @@ REGISTER_OP("ProdVirialSeAGrad") .Output("grad_net: T"); REGISTER_OP("ProdVirialSeRGrad") - .Attr("T: {float, double}") + .Attr("T: {float, double} = DT_DOUBLE") .Input("grad: T") .Input("net_deriv: T") .Input("in_deriv: T") diff --git a/source/op/prod_virial_multi_device.cc b/source/op/prod_virial_multi_device.cc index 00537179c9..02c212a2d9 100644 --- a/source/op/prod_virial_multi_device.cc +++ b/source/op/prod_virial_multi_device.cc @@ -2,7 +2,7 @@ #include "prod_virial.h" REGISTER_OP("ProdVirialSeA") - .Attr("T: {float, double}") + .Attr("T: {float, double} = DT_DOUBLE") .Input("net_deriv: T") .Input("in_deriv: T") .Input("rij: T") @@ -14,7 +14,7 @@ REGISTER_OP("ProdVirialSeA") .Output("atom_virial: T"); REGISTER_OP("ProdVirialSeR") - .Attr("T: {float, double}") + .Attr("T: {float, double} = DT_DOUBLE") .Input("net_deriv: T") .Input("in_deriv: T") .Input("rij: T") diff --git a/source/op/prod_virial_se_a_grad.cc b/source/op/prod_virial_se_a_grad.cc index 2e6056c09c..cb76d29512 100644 --- a/source/op/prod_virial_se_a_grad.cc +++ b/source/op/prod_virial_se_a_grad.cc @@ -2,7 +2,7 @@ #include "prod_virial_grad.h" REGISTER_OP("ProdVirialSeAGrad") -.Attr("T: {float, double}") +.Attr("T: {float, double} = DT_DOUBLE") .Input("grad: T") .Input("net_deriv: T") .Input("in_deriv: T") diff --git a/source/op/prod_virial_se_r_grad.cc b/source/op/prod_virial_se_r_grad.cc index 57482f0f8a..247f2ee909 100644 --- a/source/op/prod_virial_se_r_grad.cc +++ b/source/op/prod_virial_se_r_grad.cc @@ -2,7 +2,7 @@ #include "prod_virial_grad.h" REGISTER_OP("ProdVirialSeRGrad") -.Attr("T: {float, double}") +.Attr("T: {float, double} = DT_DOUBLE") .Input("grad: T") .Input("net_deriv: T") .Input("in_deriv: T") diff --git a/source/op/soft_min.cc b/source/op/soft_min.cc index cae371fc70..c30d9c409a 100644 --- a/source/op/soft_min.cc +++ b/source/op/soft_min.cc @@ -3,7 +3,7 @@ #include "soft_min_switch.h" REGISTER_OP("SoftMinSwitch") -.Attr("T: {float, double}") +.Attr("T: {float, double} = DT_DOUBLE") .Input("type: int32") .Input("rij: T") .Input("nlist: int32") diff --git a/source/op/soft_min_force.cc b/source/op/soft_min_force.cc index 15e5e3b41d..7d09da6613 100644 --- a/source/op/soft_min_force.cc +++ b/source/op/soft_min_force.cc @@ -2,7 +2,7 @@ #include "soft_min_switch_force.h" REGISTER_OP("SoftMinForce") -.Attr("T: {float, double}") +.Attr("T: {float, double} = DT_DOUBLE") .Input("du: T") .Input("sw_deriv: T") .Input("nlist: int32") diff --git a/source/op/soft_min_force_grad.cc b/source/op/soft_min_force_grad.cc index 6a161e4f4d..a7328734b6 100644 --- a/source/op/soft_min_force_grad.cc +++ b/source/op/soft_min_force_grad.cc @@ -2,7 +2,7 @@ #include "soft_min_switch_force_grad.h" REGISTER_OP("SoftMinForceGrad") -.Attr("T: {float, double}") +.Attr("T: {float, double} = DT_DOUBLE") .Input("grad: T") .Input("du: T") .Input("sw_deriv: T") diff --git a/source/op/soft_min_virial.cc b/source/op/soft_min_virial.cc index 3dcc2e6daa..3273160fe3 100644 --- a/source/op/soft_min_virial.cc +++ b/source/op/soft_min_virial.cc @@ -2,7 +2,7 @@ #include "soft_min_switch_virial.h" REGISTER_OP("SoftMinVirial") -.Attr("T: {float, double}") +.Attr("T: {float, double} = DT_DOUBLE") .Input("du: T") .Input("sw_deriv: T") .Input("rij: T") diff --git a/source/op/soft_min_virial_grad.cc b/source/op/soft_min_virial_grad.cc index c5c5399195..034aeb7a09 100644 --- a/source/op/soft_min_virial_grad.cc +++ b/source/op/soft_min_virial_grad.cc @@ -2,7 +2,7 @@ #include "soft_min_switch_virial_grad.h" REGISTER_OP("SoftMinVirialGrad") -.Attr("T: {float, double}") +.Attr("T: {float, double} = DT_DOUBLE") .Input("grad: T") .Input("du: T") .Input("sw_deriv: T") diff --git a/source/op/tabulate_multi_device.cc b/source/op/tabulate_multi_device.cc index 9d54bd18a8..6fafa5698e 100644 --- a/source/op/tabulate_multi_device.cc +++ b/source/op/tabulate_multi_device.cc @@ -2,7 +2,7 @@ #include "tabulate.h" REGISTER_OP("TabulateFusion") - .Attr("T: {float, double}") + .Attr("T: {float, double} = DT_DOUBLE") .Input("table: T") .Input("table_info: T") .Input("em_x: T") @@ -11,7 +11,7 @@ REGISTER_OP("TabulateFusion") .Output("descriptor: T"); REGISTER_OP("TabulateFusionGrad") - .Attr("T: {float, double}") + .Attr("T: {float, double} = DT_DOUBLE") .Input("table: T") .Input("table_info: T") .Input("em_x: T") diff --git a/source/op/unaggregated_grad.cc b/source/op/unaggregated_grad.cc index 5f23d639a5..56502efc55 100644 --- a/source/op/unaggregated_grad.cc +++ b/source/op/unaggregated_grad.cc @@ -3,27 +3,27 @@ #include "neighbor_list.h" REGISTER_OP("UnaggregatedDyDxS") - .Attr("T: {float, double}") + .Attr("T: {float, double} = DT_DOUBLE") .Input("y: T") .Input("w: T") .Output("dy_dx: T"); REGISTER_OP("UnaggregatedDyDx") - .Attr("T: {float, double}") + .Attr("T: {float, double} = DT_DOUBLE") .Input("z: T") .Input("w: T") .Input("dy_dx: T") .Output("dz_dx: T"); REGISTER_OP("UnaggregatedDy2DxS") - .Attr("T: {float, double}") + .Attr("T: {float, double} = DT_DOUBLE") .Input("y: T") .Input("dy: T") .Input("w: T") .Output("dy2_dx: T"); REGISTER_OP("UnaggregatedDy2Dx") - .Attr("T: {float, double}") + .Attr("T: {float, double} = DT_DOUBLE") .Input("z: T") .Input("w: T") .Input("dz_dx: T")