diff --git a/source/python.js b/source/python.js index b006145e1d..28c66b9980 100644 --- a/source/python.js +++ b/source/python.js @@ -6163,7 +6163,7 @@ python.Execution = class { }); this.registerType('torch.ClassType', class extends torch.Type { constructor(qualified_name, cu, is_module) { - super(); + super('ClassType'); this._qualified_name = qualified_name; this._is_module = is_module; this._attributes = new Map(); @@ -6328,6 +6328,9 @@ python.Execution = class { getValueType() { return this._value; } + __str__() { + return `Dict(${this.getKeyType().toString()}, ${this.getValueType().toString()})`; + } }); this.registerType('torch.DeviceObjType', class extends torch.Type { constructor() { diff --git a/source/pytorch-metadata.json b/source/pytorch-metadata.json index 3dc2bd3216..1b744f0185 100755 --- a/source/pytorch-metadata.json +++ b/source/pytorch-metadata.json @@ -222,6 +222,15 @@ { "type": "complex" } ] }, + { + "name": "aten::ComplexImplicit", + "inputs": [ + { "name": "a", "type": "Tensor" } + ], + "outputs": [ + { "type": "complex" } + ] + }, { "name": "aten::Float.Scalar", "inputs": [ @@ -267,6 +276,15 @@ { "type": "float32" } ] }, + { + "name": "aten::FloatImplicit", + "inputs": [ + { "name": "a", "type": "Tensor" } + ], + "outputs": [ + { "type": "float32" } + ] + }, { "name": "aten::Int.Scalar", "inputs": [ @@ -312,6 +330,24 @@ { "type": "int64" } ] }, + { + "name": "aten::IntImplicit", + "inputs": [ + { "name": "a", "type": "Tensor" } + ], + "outputs": [ + { "type": "int64" } + ] + }, + { + "name": "aten::ScalarImplicit", + "inputs": [ + { "name": "a", "type": "Tensor" } + ], + "outputs": [ + { "type": "Scalar" } + ] + }, { "name": "aten::__and__.Scalar", "inputs": [ @@ -332,6 +368,106 @@ { "type": "Tensor" } ] }, + { + "name": "aten::__and__.bool", + "inputs": [ + { "name": "a", "type": "boolean" }, + { "name": "b", "type": "boolean" } + ], + "outputs": [ + { "type": "boolean" } + ] + }, + { + "name": "aten::__and__.int", + "inputs": [ + { "name": "a", "type": "int64" }, + { "name": "b", "type": "int64" } + ], + "outputs": [ + { "type": "int64" } + ] + }, + { + "name": "aten::__getitem__.Dict_Tensor", + "inputs": [ + { "name": "self", "type": "Dict(Tensor, t)" }, + { "name": "key", "type": "Tensor" } + ], + "outputs": [ + { "type": "t" } + ] + }, + { + "name": "aten::__getitem__.Dict_bool", + "inputs": [ + { "name": "self", "type": "Dict(boolean, t)" }, + { "name": "key", "type": "boolean" } + ], + "outputs": [ + { "type": "t" } + ] + }, + { + "name": "aten::__getitem__.Dict_complex", + "inputs": [ + { "name": "self", "type": "Dict(complex, t)" }, + { "name": "key", "type": "complex" } + ], + "outputs": [ + { "type": "t" } + ] + }, + { + "name": "aten::__getitem__.Dict_float", + "inputs": [ + { "name": "self", "type": "Dict(float32, t)" }, + { "name": "key", "type": "float32" } + ], + "outputs": [ + { "type": "t" } + ] + }, + { + "name": "aten::__getitem__.Dict_int", + "inputs": [ + { "name": "self", "type": "Dict(int64, t)" }, + { "name": "key", "type": "int64" } + ], + "outputs": [ + { "type": "t" } + ] + }, + { + "name": "aten::__getitem__.Dict_str", + "inputs": [ + { "name": "self", "type": "Dict(string, t)" }, + { "name": "key", "type": "string" } + ], + "outputs": [ + { "type": "t" } + ] + }, + { + "name": "aten::__getitem__.str", + "inputs": [ + { "name": "s", "type": "string" }, + { "name": "index", "type": "int64" } + ], + "outputs": [ + { "type": "string" } + ] + }, + { + "name": "aten::__getitem__.t", + "inputs": [ + { "name": "list", "type": "t[]" }, + { 
"name": "idx", "type": "int64" } + ], + "outputs": [ + { "type": "t" } + ] + }, { "name": "aten::__iand__.Scalar", "inputs": [ @@ -502,6 +638,15 @@ { "type": "Tensor" } ] }, + { + "name": "aten::__not__", + "inputs": [ + { "name": "self", "type": "boolean" } + ], + "outputs": [ + { "type": "boolean" } + ] + }, { "name": "aten::__or__.Scalar", "inputs": [ @@ -1021,6 +1166,40 @@ { "name": "save_rstd", "type": "Tensor" } ] }, + { + "name": "aten::_native_multi_head_attention", + "category": "Attention", + "inputs": [ + { "name": "query", "type": "Tensor" }, + { "name": "key", "type": "Tensor" }, + { "name": "value", "type": "Tensor" }, + { "name": "embed_dim", "type": "int64" }, + { "name": "num_head", "type": "int64" }, + { "name": "qkv_weight", "type": "Tensor" }, + { "name": "qkv_bias", "type": "Tensor" }, + { "name": "proj_weight", "type": "Tensor" }, + { "name": "proj_bias", "type": "Tensor" }, + { "name": "mask", "type": "Tensor", "optional": true, "default": null }, + { "name": "need_weights", "type": "boolean", "default": true }, + { "name": "average_attn_weights", "type": "boolean", "default": true }, + { "name": "mask_type", "type": "int64", "optional": true, "default": null } + ], + "outputs": [ + { "type": "Tensor" }, + { "type": "Tensor" } + ] + }, + { + "name": "aten::_nested_tensor_from_mask", + "inputs": [ + { "name": "t", "type": "Tensor" }, + { "name": "mask", "type": "Tensor" }, + { "name": "mask_check", "type": "boolean", "default": true } + ], + "outputs": [ + { "type": "Tensor" } + ] + }, { "name": "aten::_pack_padded_sequence", "inputs": [ @@ -1186,6 +1365,34 @@ { "type": "Tensor" } ] }, + { + "name": "aten::_transformer_encoder_layer_fwd", + "inputs": [ + { "name": "src", "type": "Tensor" }, + { "name": "embed_dim", "type": "int64" }, + { "name": "num_heads", "type": "int64" }, + { "name": "qkv_weight", "type": "Tensor" }, + { "name": "qkv_bias", "type": "Tensor" }, + { "name": "proj_weight", "type": "Tensor" }, + { "name": "proj_bias", "type": "Tensor" }, + { "name": "use_gelu", "type": "boolean" }, + { "name": "norm_first", "type": "boolean" }, + { "name": "eps", "type": "float32" }, + { "name": "norm_weight_1", "type": "Tensor" }, + { "name": "norm_bias_1", "type": "Tensor" }, + { "name": "norm_weight_2", "type": "Tensor" }, + { "name": "norm_bias_2", "type": "Tensor" }, + { "name": "ffn_weight_1", "type": "Tensor" }, + { "name": "ffn_bias_1", "type": "Tensor" }, + { "name": "ffn_weight_2", "type": "Tensor" }, + { "name": "ffn_bias_2", "type": "Tensor" }, + { "name": "mask", "type": "Tensor", "optional": true, "default": null }, + { "name": "mask_type", "type": "int64", "optional": true, "default": null } + ], + "outputs": [ + { "type": "Tensor" } + ] + }, { "name": "aten::_unique2", "inputs": [ @@ -1404,6 +1611,7 @@ }, { "name": "aten::adaptive_avg_pool2d.out", + "category": "Pool", "inputs": [ { "name": "self", "type": "Tensor" }, { "name": "output_size", "type": "SymInt[2]" } @@ -2027,7 +2235,7 @@ "name": "aten::any.dimname", "inputs": [ { "name": "self", "type": "Tensor" }, - { "name": "dim", "type": "Dimname" }, + { "name": "dim", "type": "string" }, { "name": "keepdim", "type": "boolean", "default": false } ], "outputs": [ @@ -2038,7 +2246,7 @@ "name": "aten::any.dimname_out", "inputs": [ { "name": "self", "type": "Tensor" }, - { "name": "dim", "type": "Dimname" }, + { "name": "dim", "type": "string" }, { "name": "keepdim", "type": "boolean", "default": false } ], "outputs": [ @@ -2105,6 +2313,16 @@ { "type": "boolean" } ] }, + { + "name": "aten::append.t", + 
"inputs": [ + { "name": "self", "type": "t[]" }, + { "name": "el", "type": "t" } + ], + "outputs": [ + { "type": "t[]" } + ] + }, { "name": "aten::arange", "inputs": [ @@ -3106,6 +3324,33 @@ { "type": "Tensor" } ] }, + { + "name": "aten::ceil.Scalar", + "inputs": [ + { "name": "a", "type": "Scalar" } + ], + "outputs": [ + { "type": "Scalar" } + ] + }, + { + "name": "aten::ceil.float", + "inputs": [ + { "name": "a", "type": "float32" } + ], + "outputs": [ + { "type": "int64" } + ] + }, + { + "name": "aten::ceil.int", + "inputs": [ + { "name": "a", "type": "int64" } + ], + "outputs": [ + { "type": "int64" } + ] + }, { "name": "aten::ceil.out", "inputs": [ @@ -4340,6 +4585,22 @@ { "type": "Tensor" } ] }, + { + "name": "aten::dict", + "inputs": [], + "outputs": [ + { "type": "Dict(string, Tensor)" } + ] + }, + { + "name": "aten::dict.Dict_str", + "inputs": [ + { "name": "self", "type": "Dict(string, t)" } + ], + "outputs": [ + { "type": "Dict(string, t)" } + ] + }, { "name": "aten::diff", "inputs": [ @@ -4386,6 +4647,16 @@ { "type": "Tensor" } ] }, + { + "name": "aten::div", + "inputs": [ + { "name": "a", "type": "Scalar" }, + { "name": "b", "type": "Scalar" } + ], + "outputs": [ + { "type": "float32" } + ] + }, { "name": "aten::div.Scalar", "inputs": [ @@ -4407,6 +4678,27 @@ { "type": "Tensor" } ] }, + { + "name": "aten::div.Scalar_mode_out", + "inputs": [ + { "name": "self", "type": "Tensor" }, + { "name": "other", "type": "Scalar" }, + { "name": "rounding_mode", "type": "string", "optional": true, "kwarg_only": true } + ], + "outputs": [ + { "type": "Tensor" } + ] + }, + { + "name": "aten::div.Scalar_out", + "inputs": [ + { "name": "self", "type": "Tensor" }, + { "name": "other", "type": "Scalar" } + ], + "outputs": [ + { "type": "Tensor" } + ] + }, { "name": "aten::div.Tensor", "inputs": [ @@ -4428,6 +4720,36 @@ { "type": "Tensor" } ] }, + { + "name": "aten::div.complex", + "inputs": [ + { "name": "a", "type": "complex" }, + { "name": "b", "type": "complex" } + ], + "outputs": [ + { "type": "complex" } + ] + }, + { + "name": "aten::div.float", + "inputs": [ + { "name": "a", "type": "float32" }, + { "name": "b", "type": "float32" } + ], + "outputs": [ + { "type": "float32" } + ] + }, + { + "name": "aten::div.int", + "inputs": [ + { "name": "a", "type": "int64" }, + { "name": "b", "type": "int64" } + ], + "outputs": [ + { "type": "float32" } + ] + }, { "name": "aten::div.out", "inputs": [ @@ -4839,6 +5161,16 @@ { "type": "Tensor" } ] }, + { + "name": "aten::eq", + "inputs": [ + { "name": "a", "type": "Scalar" }, + { "name": "b", "type": "Scalar" } + ], + "outputs": [ + { "type": "boolean" } + ] + }, { "name": "aten::eq.Scalar", "inputs": [ @@ -4910,25 +5242,75 @@ ] }, { - "name": "aten::eq.device", + "name": "aten::eq.complex", "inputs": [ - { "name": "a", "type": "Device" }, - { "name": "b", "type": "Device" } + { "name": "a", "type": "complex" }, + { "name": "b", "type": "complex" } ], "outputs": [ { "type": "boolean" } ] }, { - "name": "aten::eq.enum", + "name": "aten::eq.complex_float", "inputs": [ - { "name": "a", "type": "AnyEnumType" }, + { "name": "a", "type": "complex" }, + { "name": "b", "type": "float32" } + ], + "outputs": [ + { "type": "boolean" } + ] + }, + { + "name": "aten::eq.device", + "inputs": [ + { "name": "a", "type": "Device" }, + { "name": "b", "type": "Device" } + ], + "outputs": [ + { "type": "boolean" } + ] + }, + { + "name": "aten::eq.enum", + "inputs": [ + { "name": "a", "type": "AnyEnumType" }, { "name": "b", "type": "AnyEnumType" } ], "outputs": [ { "type": 
"boolean" } ] }, + { + "name": "aten::eq.float", + "inputs": [ + { "name": "a", "type": "float32" }, + { "name": "b", "type": "float32" } + ], + "outputs": [ + { "type": "boolean" } + ] + }, + { + "name": "aten::eq.float_complex", + "inputs": [ + { "name": "a", "type": "float32" }, + { "name": "b", "type": "complex" } + ], + "outputs": [ + { "type": "boolean" } + ] + }, + { + "name": "aten::eq.float_int", + "inputs": [ + { "name": "a", "type": "float32" }, + { "name": "b", "type": "int64" } + ], + "outputs": [ + { "type": "boolean" } + ] + }, { "name": "aten::eq.float_list", "inputs": [ @@ -4939,6 +5321,26 @@ { "type": "boolean" } ] }, + { + "name": "aten::eq.int", + "inputs": [ + { "name": "a", "type": "int64" }, + { "name": "b", "type": "int64" } + ], + "outputs": [ + { "type": "boolean" } + ] + }, + { + "name": "aten::eq.int_float", + "inputs": [ + { "name": "a", "type": "int64" }, + { "name": "b", "type": "float32" } + ], + "outputs": [ + { "type": "boolean" } + ] + }, { "name": "aten::eq.int_list", "inputs": [ @@ -4949,6 +5351,16 @@ { "type": "boolean" } ] }, + { + "name": "aten::eq.str", + "inputs": [ + { "name": "a", "type": "string" }, + { "name": "b", "type": "string" } + ], + "outputs": [ + { "type": "boolean" } + ] + }, { "name": "aten::eq.str_list", "inputs": [ @@ -5091,6 +5503,14 @@ { "type": "Tensor" } ] }, + { + "name": "aten::extend.t", + "inputs": [ + { "name": "self", "type": "t[]" }, + { "name": "other", "type": "t[]" } + ], + "outputs": [] + }, { "name": "aten::eye", "inputs": [ @@ -5765,6 +6185,33 @@ { "type": "Tensor" } ] }, + { + "name": "aten::floor.Scalar", + "inputs": [ + { "name": "a", "type": "Scalar" } + ], + "outputs": [ + { "type": "Scalar" } + ] + }, + { + "name": "aten::floor.float", + "inputs": [ + { "name": "a", "type": "float32" } + ], + "outputs": [ + { "type": "float32" } + ] + }, + { + "name": "aten::floor.int", + "inputs": [ + { "name": "a", "type": "int64" } + ], + "outputs": [ + { "type": "int64" } + ] + }, { "name": "aten::floor.out", "inputs": [ @@ -5833,6 +6280,56 @@ { "type": "Tensor" } ] }, + { + "name": "aten::floordiv.Scalar", + "inputs": [ + { "name": "a", "type": "Scalar" }, + { "name": "b", "type": "Scalar" } + ], + "outputs": [ + { "type": "Scalar" } + ] + }, + { + "name": "aten::floordiv.float", + "inputs": [ + { "name": "a", "type": "float32" }, + { "name": "b", "type": "float32" } + ], + "outputs": [ + { "type": "float32" } + ] + }, + { + "name": "aten::floordiv.float_int", + "inputs": [ + { "name": "a", "type": "float32" }, + { "name": "b", "type": "int64" } + ], + "outputs": [ + { "type": "float32" } + ] + }, + { + "name": "aten::floordiv.int", + "inputs": [ + { "name": "a", "type": "int64" }, + { "name": "b", "type": "int64" } + ], + "outputs": [ + { "type": "int64" } + ] + }, + { + "name": "aten::floordiv.int_float", + "inputs": [ + { "name": "a", "type": "int64" }, + { "name": "b", "type": "float32" } + ], + "outputs": [ + { "type": "float32" } + ] + }, { "name": "aten::fmod.Scalar", "inputs": [ @@ -5875,6 +6372,7 @@ }, { "name": "aten::format", + "is_vararg": true, "inputs": [ { "name": "self", "type": "string" } ], @@ -6343,6 +6841,18 @@ { "type": "Tensor" } ] }, + { + "name": "aten::grid_sampler.legacy", + "inputs": [ + { "name": "input", "type": "Tensor" }, + { "name": "grid", "type": "Tensor" }, + { "name": "interpolation_mode", "type": "int64" }, + { "name": "padding_mode", "type": "int64" } + ], + "outputs": [ + { "type": "Tensor" } + ] + }, { "name": "aten::group_norm", "category": "Normalization", @@ -6396,6 +6906,16 @@ 
{ "name": "?", "type": "Tensor" } ] }, + { + "name": "aten::gt", + "inputs": [ + { "name": "a", "type": "Scalar" }, + { "name": "b", "type": "Scalar" } + ], + "outputs": [ + { "type": "boolean" } + ] + }, { "name": "aten::gt.Scalar", "inputs": [ @@ -6476,6 +6996,16 @@ { "type": "boolean" } ] }, + { + "name": "aten::gt.str", + "inputs": [ + { "name": "a", "type": "string" }, + { "name": "b", "type": "string" } + ], + "outputs": [ + { "type": "boolean" } + ] + }, { "name": "aten::hamming_window", "inputs": [ @@ -7368,6 +7898,60 @@ { "type": "string" } ] }, + { + "name": "aten::keys.Tensor", + "inputs": [ + { "name": "self", "type": "Dict(Tensor, t)" } + ], + "outputs": [ + { "type": "Tensor[]" } + ] + }, + { + "name": "aten::keys.bool", + "inputs": [ + { "name": "self", "type": "Dict(boolean, t)" } + ], + "outputs": [ + { "type": "boolean[]" } + ] + }, + { + "name": "aten::keys.complex", + "inputs": [ + { "name": "self", "type": "Dict(complex, t)" } + ], + "outputs": [ + { "type": "complex[]" } + ] + }, + { + "name": "aten::keys.float", + "inputs": [ + { "name": "self", "type": "Dict(float32, t)" } + ], + "outputs": [ + { "type": "float32[]" } + ] + }, + { + "name": "aten::keys.int", + "inputs": [ + { "name": "self", "type": "Dict(int64, t)" } + ], + "outputs": [ + { "type": "int64[]" } + ] + }, + { + "name": "aten::keys.str", + "inputs": [ + { "name": "self", "type": "Dict(string, t)" } + ], + "outputs": [ + { "type": "string[]" } + ] + }, { "name": "aten::kl_div", "inputs": [ @@ -7458,6 +8042,16 @@ { "type": "Tensor" } ] }, + { + "name": "aten::le", + "inputs": [ + { "name": "a", "type": "Scalar" }, + { "name": "b", "type": "Scalar" } + ], + "outputs": [ + { "type": "boolean" } + ] + }, { "name": "aten::le.Scalar", "inputs": [ @@ -7538,6 +8132,16 @@ { "type": "boolean" } ] }, + { + "name": "aten::le.str", + "inputs": [ + { "name": "a", "type": "string" }, + { "name": "b", "type": "string" } + ], + "outputs": [ + { "type": "boolean" } + ] + }, { "name": "aten::leaky_relu", "category": "Activation", @@ -8070,25 +8674,25 @@ ] }, { - "name": "aten::log", + "name": "aten::list", "inputs": [ - { "name": "self", "type": "Tensor" } + { "name": "t", "type": "string" } ], "outputs": [ - { "type": "Tensor" } + { "type": "string[]" } ] }, { - "name": "aten::log.out", + "name": "aten::list.t", "inputs": [ - { "name": "self", "type": "Tensor" } + { "name": "l", "type": "t[]" } ], "outputs": [ - { "type": "Tensor" } + { "type": "t[]" } ] }, { - "name": "aten::log10", + "name": "aten::log", "inputs": [ { "name": "self", "type": "Tensor" } ], @@ -8097,7 +8701,7 @@ ] }, { - "name": "aten::log10.out", + "name": "aten::log.out", "inputs": [ { "name": "self", "type": "Tensor" } ], @@ -8106,7 +8710,7 @@ ] }, { - "name": "aten::log10_", + "name": "aten::log10", "inputs": [ { "name": "self", "type": "Tensor" } ], @@ -8115,27 +8719,81 @@ ] }, { - "name": "aten::log1p", + "name": "aten::log10.Scalar", "inputs": [ - { "name": "self", "type": "Tensor" } + { "name": "a", "type": "Scalar" } ], "outputs": [ - { "type": "Tensor" } + { "type": "Scalar" } ] }, { - "name": "aten::log1p.out", + "name": "aten::log10.complex", "inputs": [ - { "name": "self", "type": "Tensor" } + { "name": "a", "type": "complex" } ], "outputs": [ - { "type": "Tensor" } + { "type": "complex" } ] }, { - "name": "aten::log1p_", + "name": "aten::log10.float", "inputs": [ - { "name": "self", "type": "Tensor" } + { "name": "a", "type": "float32" } + ], + "outputs": [ + { "type": "float32" } + ] + }, + { + "name": "aten::log10.int", + "inputs": [ + { 
"name": "a", "type": "int64" } + ], + "outputs": [ + { "type": "float32" } + ] + }, + { + "name": "aten::log10.out", + "inputs": [ + { "name": "self", "type": "Tensor" } + ], + "outputs": [ + { "type": "Tensor" } + ] + }, + { + "name": "aten::log10_", + "inputs": [ + { "name": "self", "type": "Tensor" } + ], + "outputs": [ + { "type": "Tensor" } + ] + }, + { + "name": "aten::log1p", + "inputs": [ + { "name": "self", "type": "Tensor" } + ], + "outputs": [ + { "type": "Tensor" } + ] + }, + { + "name": "aten::log1p.out", + "inputs": [ + { "name": "self", "type": "Tensor" } + ], + "outputs": [ + { "type": "Tensor" } + ] + }, + { + "name": "aten::log1p_", + "inputs": [ + { "name": "self", "type": "Tensor" } ], "outputs": [ { "type": "Tensor" } @@ -8754,6 +9412,16 @@ { "type": "Tensor" } ] }, + { + "name": "aten::lt", + "inputs": [ + { "name": "a", "type": "Scalar" }, + { "name": "b", "type": "Scalar" } + ], + "outputs": [ + { "type": "boolean" } + ] + }, { "name": "aten::lt.Scalar", "inputs": [ @@ -8794,6 +9462,56 @@ { "type": "Tensor" } ] }, + { + "name": "aten::lt.float", + "inputs": [ + { "name": "a", "type": "float32" }, + { "name": "b", "type": "float32" } + ], + "outputs": [ + { "type": "boolean" } + ] + }, + { + "name": "aten::lt.float_int", + "inputs": [ + { "name": "a", "type": "float32" }, + { "name": "b", "type": "int64" } + ], + "outputs": [ + { "type": "boolean" } + ] + }, + { + "name": "aten::lt.int", + "inputs": [ + { "name": "a", "type": "int64" }, + { "name": "b", "type": "int64" } + ], + "outputs": [ + { "type": "boolean" } + ] + }, + { + "name": "aten::lt.int_float", + "inputs": [ + { "name": "a", "type": "int64" }, + { "name": "b", "type": "float32" } + ], + "outputs": [ + { "type": "boolean" } + ] + }, + { + "name": "aten::lt.str", + "inputs": [ + { "name": "a", "type": "string" }, + { "name": "b", "type": "string" } + ], + "outputs": [ + { "type": "boolean" } + ] + }, { "name": "aten::manual_seed", "inputs": [ @@ -9632,6 +10350,16 @@ { "type": "Tensor" } ] }, + { + "name": "aten::mul", + "inputs": [ + { "name": "a", "type": "Scalar" }, + { "name": "b", "type": "Scalar" } + ], + "outputs": [ + { "type": "Scalar" } + ] + }, { "name": "aten::mul.Scalar", "inputs": [ @@ -9662,6 +10390,36 @@ { "type": "Tensor" } ] }, + { + "name": "aten::mul.float_int", + "inputs": [ + { "name": "a", "type": "float32" }, + { "name": "b", "type": "int64" } + ], + "outputs": [ + { "type": "float32" } + ] + }, + { + "name": "aten::mul.int", + "inputs": [ + { "name": "a", "type": "int64" }, + { "name": "b", "type": "int64" } + ], + "outputs": [ + { "type": "int64" } + ] + }, + { + "name": "aten::mul.int_float", + "inputs": [ + { "name": "a", "type": "int64" }, + { "name": "b", "type": "float32" } + ], + "outputs": [ + { "type": "float32" } + ] + }, { "name": "aten::mul.left_t", "inputs": [ @@ -10036,6 +10794,16 @@ { "type": "boolean" } ] }, + { + "name": "aten::ne.float", + "inputs": [ + { "name": "a", "type": "float32" }, + { "name": "b", "type": "float32" } + ], + "outputs": [ + { "type": "boolean" } + ] + }, { "name": "aten::ne.float_list", "inputs": [ @@ -10046,6 +10814,16 @@ { "type": "boolean" } ] }, + { + "name": "aten::ne.int", + "inputs": [ + { "name": "a", "type": "int64" }, + { "name": "b", "type": "int64" } + ], + "outputs": [ + { "type": "boolean" } + ] + }, { "name": "aten::ne.int_list", "inputs": [ @@ -10056,6 +10834,16 @@ { "type": "boolean" } ] }, + { + "name": "aten::ne.str", + "inputs": [ + { "name": "a", "type": "string" }, + { "name": "b", "type": "string" } + ], + "outputs": [ 
+ { "type": "boolean" } + ] + }, { "name": "aten::ne.str_list", "inputs": [ @@ -10075,6 +10863,42 @@ { "type": "Tensor" } ] }, + { + "name": "aten::neg.Scalar", + "inputs": [ + { "name": "a", "type": "Scalar" } + ], + "outputs": [ + { "type": "Scalar" } + ] + }, + { + "name": "aten::neg.complex", + "inputs": [ + { "name": "a", "type": "complex" } + ], + "outputs": [ + { "type": "complex" } + ] + }, + { + "name": "aten::neg.float", + "inputs": [ + { "name": "a", "type": "float32" } + ], + "outputs": [ + { "type": "float32" } + ] + }, + { + "name": "aten::neg.int", + "inputs": [ + { "name": "a", "type": "int64" } + ], + "outputs": [ + { "type": "int64" } + ] + }, { "name": "aten::neg.out", "inputs": [ @@ -11647,6 +12471,16 @@ { "type": "Tensor" } ] }, + { + "name": "aten::remainder", + "inputs": [ + { "name": "a", "type": "Scalar" }, + { "name": "b", "type": "Scalar" } + ], + "outputs": [ + { "type": "Scalar" } + ] + }, { "name": "aten::remainder.Scalar", "inputs": [ @@ -11667,6 +12501,16 @@ { "type": "Tensor" } ] }, + { + "name": "aten::remainder.Scalar_Tensor_out", + "inputs": [ + { "name": "self", "type": "Scalar" }, + { "name": "other", "type": "Tensor" } + ], + "outputs": [ + { "type": "Tensor" } + ] + }, { "name": "aten::remainder.Scalar_out", "inputs": [ @@ -11698,7 +12542,7 @@ ] }, { - "name": "aten::remainder.float32", + "name": "aten::remainder.float", "inputs": [ { "name": "a", "type": "float32" }, { "name": "b", "type": "float32" } @@ -11707,6 +12551,16 @@ { "type": "float32" } ] }, + { + "name": "aten::remainder.float_int", + "inputs": [ + { "name": "a", "type": "float32" }, + { "name": "b", "type": "int64" } + ], + "outputs": [ + { "type": "float32" } + ] + }, { "name": "aten::remainder.int", "inputs": [ @@ -11717,6 +12571,16 @@ { "type": "int64" } ] }, + { + "name": "aten::remainder.int_float", + "inputs": [ + { "name": "a", "type": "int64" }, + { "name": "b", "type": "float32" } + ], + "outputs": [ + { "type": "float32" } + ] + }, { "name": "aten::remainder_.Scalar", "inputs": [ @@ -13280,43 +14144,79 @@ ] }, { - "name": "aten::sqrt.out", + "name": "aten::sqrt.Scalar", "inputs": [ - { "name": "self", "type": "Tensor" } + { "name": "a", "type": "Scalar" } ], "outputs": [ - { "type": "Tensor" } + { "type": "Scalar" } ] }, { - "name": "aten::sqrt_", + "name": "aten::sqrt.complex", "inputs": [ - { "name": "self", "type": "Tensor" } + { "name": "a", "type": "complex" } ], "outputs": [ - { "type": "Tensor" } + { "type": "complex" } ] }, { - "name": "aten::square", + "name": "aten::sqrt.float", "inputs": [ - { "name": "self", "type": "Tensor" } + { "name": "a", "type": "float32" } ], "outputs": [ - { "type": "Tensor" } + { "type": "float32" } ] }, { - "name": "aten::square.out", + "name": "aten::sqrt.int", "inputs": [ - { "name": "self", "type": "Tensor" } + { "name": "a", "type": "int64" } ], "outputs": [ - { "type": "Tensor" } + { "type": "float32" } ] }, { - "name": "aten::square_", + "name": "aten::sqrt.out", + "inputs": [ + { "name": "self", "type": "Tensor" } + ], + "outputs": [ + { "type": "Tensor" } + ] + }, + { + "name": "aten::sqrt_", + "inputs": [ + { "name": "self", "type": "Tensor" } + ], + "outputs": [ + { "type": "Tensor" } + ] + }, + { + "name": "aten::square", + "inputs": [ + { "name": "self", "type": "Tensor" } + ], + "outputs": [ + { "type": "Tensor" } + ] + }, + { + "name": "aten::square.out", + "inputs": [ + { "name": "self", "type": "Tensor" } + ], + "outputs": [ + { "type": "Tensor" } + ] + }, + { + "name": "aten::square_", "inputs": [ { "name": "self", 
"type": "Tensor" } ], @@ -15113,6 +16013,69 @@ { "type": "Tensor" } ] }, + { + "name": "aten::values", + "inputs": [ + { "name": "self", "type": "Tensor" } + ], + "outputs": [ + { "type": "Tensor" } + ] + }, + { + "name": "aten::values.Tensor", + "inputs": [ + { "name": "self", "type": "Dict(Tensor, t)" } + ], + "outputs": [ + { "type": "t[]" } + ] + }, + { + "name": "aten::values.bool", + "inputs": [ + { "name": "self", "type": "Dict(boolean, t)" } + ], + "outputs": [ + { "type": "t[]" } + ] + }, + { + "name": "aten::values.complex", + "inputs": [ + { "name": "self", "type": "Dict(complex, t)" } + ], + "outputs": [ + { "type": "t[]" } + ] + }, + { + "name": "aten::values.float", + "inputs": [ + { "name": "self", "type": "Dict(float32, t)" } + ], + "outputs": [ + { "type": "t[]" } + ] + }, + { + "name": "aten::values.int", + "inputs": [ + { "name": "self", "type": "Dict(int64, t)" } + ], + "outputs": [ + { "type": "t[]" } + ] + }, + { + "name": "aten::values.str", + "inputs": [ + { "name": "self", "type": "Dict(string, t)" } + ], + "outputs": [ + { "type": "t[]" } + ] + }, { "name": "aten::var", "inputs": [ @@ -15415,6 +16378,14 @@ { "type": "t" } ] }, + { + "name": "aten::warn", + "inputs": [ + { "name": "message", "type": "string" }, + { "name": "stacklevel", "type": "int64", "default": 2 } + ], + "outputs": [] + }, { "name": "aten::where", "inputs": [ @@ -15810,6 +16781,42 @@ { "type": "Tensor" } ] }, + { + "name": "prim::abs.Scalar", + "inputs": [ + { "name": "a", "type": "Scalar" } + ], + "outputs": [ + { "type": "Scalar" } + ] + }, + { + "name": "prim::abs.complex", + "inputs": [ + { "name": "a", "type": "complex" } + ], + "outputs": [ + { "type": "float32" } + ] + }, + { + "name": "prim::abs.float", + "inputs": [ + { "name": "a", "type": "float32" } + ], + "outputs": [ + { "type": "float32" } + ] + }, + { + "name": "prim::abs.int", + "inputs": [ + { "name": "a", "type": "int64" } + ], + "outputs": [ + { "type": "int64" } + ] + }, { "name": "prim::data", "inputs": [ @@ -15999,6 +17006,15 @@ { "type": "boolean" } ] }, + { + "name": "prim::isinstance", + "inputs": [ + { "name": "to_check", "type": "Any" } + ], + "outputs": [ + { "type": "boolean" } + ] + }, { "name": "prim::itemsize", "inputs": [ @@ -16017,6 +17033,220 @@ { "type": "Layout" } ] }, + { + "name": "prim::max", + "inputs": [ + { "name": "a", "type": "Scalar" }, + { "name": "b", "type": "Scalar" } + ], + "outputs": [ + { "type": "Scalar" } + ] + }, + { + "name": "prim::max.bool_list", + "inputs": [ + { "name": "l", "type": "boolean[]" }, + { "name": "r", "type": "boolean[]" } + ], + "outputs": [ + { "type": "boolean[]" } + ] + }, + { + "name": "prim::max.float", + "inputs": [ + { "name": "a", "type": "float32" }, + { "name": "b", "type": "float32" } + ], + "outputs": [ + { "type": "float32" } + ] + }, + { + "name": "prim::max.float_int", + "inputs": [ + { "name": "a", "type": "float32" }, + { "name": "b", "type": "int64" } + ], + "outputs": [ + { "type": "float32" } + ] + }, + { + "name": "prim::max.float_list", + "inputs": [ + { "name": "l", "type": "float32[]" }, + { "name": "r", "type": "float32[]" } + ], + "outputs": [ + { "type": "float32[]" } + ] + }, + { + "name": "prim::max.int", + "inputs": [ + { "name": "a", "type": "int64" }, + { "name": "b", "type": "int64" } + ], + "outputs": [ + { "type": "int64" } + ] + }, + { + "name": "prim::max.int_float", + "inputs": [ + { "name": "a", "type": "int64" }, + { "name": "b", "type": "float32" } + ], + "outputs": [ + { "type": "float32" } + ] + }, + { + "name": 
"prim::max.int_list", + "inputs": [ + { "name": "l", "type": "int64[]" }, + { "name": "r", "type": "int64[]" } + ], + "outputs": [ + { "type": "int64[]" } + ] + }, + { + "name": "prim::max.self_bool", + "inputs": [ + { "name": "self", "type": "boolean[]" } + ], + "outputs": [ + { "type": "boolean" } + ] + }, + { + "name": "prim::max.self_float", + "inputs": [ + { "name": "self", "type": "float32[]" } + ], + "outputs": [ + { "type": "float32" } + ] + }, + { + "name": "prim::max.self_int", + "inputs": [ + { "name": "self", "type": "int64[]" } + ], + "outputs": [ + { "type": "int64" } + ] + }, + { + "name": "prim::min", + "inputs": [ + { "name": "a", "type": "Scalar" }, + { "name": "b", "type": "Scalar" } + ], + "outputs": [ + { "type": "Scalar" } + ] + }, + { + "name": "prim::min.bool_list", + "inputs": [ + { "name": "l", "type": "boolean[]" }, + { "name": "r", "type": "boolean[]" } + ], + "outputs": [ + { "type": "boolean[]" } + ] + }, + { + "name": "prim::min.float", + "inputs": [ + { "name": "a", "type": "float32" }, + { "name": "b", "type": "float32" } + ], + "outputs": [ + { "type": "float32" } + ] + }, + { + "name": "prim::min.float_int", + "inputs": [ + { "name": "a", "type": "float32" }, + { "name": "b", "type": "int64" } + ], + "outputs": [ + { "type": "float32" } + ] + }, + { + "name": "prim::min.float_list", + "inputs": [ + { "name": "l", "type": "float32[]" }, + { "name": "r", "type": "float32[]" } + ], + "outputs": [ + { "type": "float32[]" } + ] + }, + { + "name": "prim::min.int", + "inputs": [ + { "name": "a", "type": "int64" }, + { "name": "b", "type": "int64" } + ], + "outputs": [ + { "type": "int64" } + ] + }, + { + "name": "prim::min.int_float", + "inputs": [ + { "name": "a", "type": "int64" }, + { "name": "b", "type": "float32" } + ], + "outputs": [ + { "type": "float32" } + ] + }, + { + "name": "prim::min.int_list", + "inputs": [ + { "name": "l", "type": "int64[]" }, + { "name": "r", "type": "int64[]" } + ], + "outputs": [ + { "type": "int64[]" } + ] + }, + { + "name": "prim::min.self_bool", + "inputs": [ + { "name": "self", "type": "boolean[]" } + ], + "outputs": [ + { "type": "boolean" } + ] + }, + { + "name": "prim::min.self_float", + "inputs": [ + { "name": "self", "type": "float32[]" } + ], + "outputs": [ + { "type": "float32" } + ] + }, + { + "name": "prim::min.self_int", + "inputs": [ + { "name": "self", "type": "int64[]" } + ], + "outputs": [ + { "type": "int64" } + ] + }, { "name": "prim::name", "inputs": [ @@ -16053,6 +17283,15 @@ { "type": "boolean" } ] }, + { + "name": "prim::shape", + "inputs": [ + { "name": "self", "type": "Tensor" } + ], + "outputs": [ + { "type": "int64[]" } + ] + }, { "name": "prim::type", "inputs": [ diff --git a/source/pytorch.js b/source/pytorch.js index f6db6b9d1d..d46c33a022 100644 --- a/source/pytorch.js +++ b/source/pytorch.js @@ -122,6 +122,8 @@ pytorch.Graph = class { initializers.set(obj, obj); } queue.push(obj); + } else if (pytorch.Utility.isInstance(obj, 'torch.Value') || pytorch.Utility.isInstance(obj, 'torch.Node')) { + continue; } else if (obj && obj.__class__) { obj.__parent__ = module; obj.__name__ = obj.__name__ || key; @@ -145,18 +147,29 @@ pytorch.Graph = class { node === graph.return_node()) { continue; } - if (node.kind() === 'prim::ListConstruct' && - node.outputs().length === 1 && - node.outputs().every((output) => output.uses().length === 1) && - node.inputs().every((input) => pytorch.Utility.isTensor(input.value))) { - continue; - } - if (node.kind() === 'prim::ListConstruct' && + if (node.kind() === 
'prim::TupleConstruct' && node.inputs().length === 0 && node.outputs().length === 1 && node.outputs().every((output) => output.uses().length === 0)) { continue; } + if (node.kind() === 'prim::ListConstruct') { + if (node.outputs().length === 1 && + node.outputs().every((output) => output.uses().length === 1) && + node.inputs().every((input) => pytorch.Utility.isTensor(input.value))) { + continue; + } + if (node.inputs().length === 0 && + node.outputs().length === 1 && + node.outputs().every((output) => output.uses().length === 0)) { + continue; + } + if (node.inputs().every((value) => value && (pytorch.Utility.isInstance(value.type(), 'torch.IntType') || pytorch.Utility.isInstance(value.type(), 'torch.FloatType') || pytorch.Utility.isInstance(value.type(), 'torch.StringType') || pytorch.Utility.isInstance(value.type(), 'torch.ComplexType'))) && + node.outputs().length === 1 && + node.outputs().every((output) => output.uses().length === 1)) { + continue; + } + } if (node.kind() === 'prim::ListUnpack' && node.inputs().length === 1 && node.inputs().every((input) => input.uses().length === 1) && @@ -450,6 +463,30 @@ pytorch.Node = class { const value = values.map(identifier); argument = new pytorch.Argument(name, [value]); } + } else if (pytorch.Utility.isInstance(input, 'torch.Value') && !pytorch.Utility.isTensor(input.value)) { + if (input.node() === null && input.value !== undefined) { + argument = new pytorch.Argument(name, input.value, 'attribute'); + } else if (pytorch.Utility.isInstance(input.type(), 'torch.ListType')) { + if (input.node() && input.node().kind() === 'prim::ListConstruct' && input.uses().length === 1 && + input.node().inputs().every((value) => pytorch.Utility.isInstance(value.type(), 'torch.IntType') || pytorch.Utility.isInstance(value.type(), 'torch.FloatType') || pytorch.Utility.isInstance(value.type(), 'torch.StringType') || pytorch.Utility.isInstance(value.type(), 'torch.ComplexType') || pytorch.Utility.isInstance(value.type(), 'torch.TensorType'))) { + const list = input.node().inputs(); + const args = list.map((value) => { + if (value.uses().length === 1 && value.node() === input.node() && value.value !== undefined) { + return value.value; + } + const identifier = value.unique().toString(); + return values.map(identifier); + }); + argument = new pytorch.Argument(name, args, pytorch.Utility.toType(input.type())); + } else { + const identifier = input.unique().toString(); + argument = new pytorch.Argument(name, [values.map(identifier)]); + } + } else { + const identifier = input.unique().toString(); + const value = values.map(identifier); + argument = new pytorch.Argument(name, [value]); + } } else if (pytorch.Utility.isTensor(input.value) || input.value === undefined || input.value === null) { let list = [input]; if (input.node() && @@ -473,10 +510,6 @@ pytorch.Node = class { return values.map(identifier); }); argument = new pytorch.Argument(name, args); - } else if (pytorch.Utility.isInstance(input, 'torch.Value')) { - const identifier = input.unique().toString(); - const value = values.map(identifier); - argument = new pytorch.Argument(name, [value]); } else if (Array.isArray(input.value) && input.value.some((value) => pytorch.Utility.isInstance(value, 'torch.Value'))) { const args = input.value.map((value) => { if (pytorch.Utility.isInstance(value, 'torch.Value')) { @@ -1154,6 +1187,7 @@ pytorch.Container.Zip = class extends pytorch.Container { let torchscript = reader.has_record('constants.pkl'); const version = reader.version(); if (torchscript) { + 
execution.trace = false; const module = torch.jit.load(reader); execution.trace = true; if (module.data && module.data.forward) { @@ -1222,6 +1256,7 @@ pytorch.Container.ModelJson = class extends pytorch.Container { this.producer = this._model.producerName + (this._model.producerVersion ? ` v${this._model.producerVersion}` : ''); } this.format = reader.has_record('attributes.pkl') ? 'TorchScript v1.1' : 'TorchScript v1.0'; + execution.trace = false; const module = torch.jit.load(reader); execution.trace = true; if (module.data && module.data.forward) { @@ -1874,11 +1909,13 @@ pytorch.Execution = class extends python.Execution { case 't': case 't1': case 't2': type = new torch.Type(); break; case 'Future(t)': type = new torch.FutureType(new torch.Type()); break; case 'AnyClassType': type = null; break; - case 'Dict[string,t]': type = new torch.DictType(new torch.StringType(), new torch.Type()); break; - case 'Dict[int64,t]': type = new torch.DictType(new torch.IntType(), new torch.Type()); break; - case 'Dict[float32,t]': type = new torch.DictType(new torch.FloatType(), new torch.Type()); break; - case 'Dict[boolean,t]': type = new torch.DictType(new torch.BoolType(), new torch.Type()); break; - case 'Dict[Tensor,t]': type = new torch.DictType(new torch.TensorType(), new torch.Type()); break; + case 'Dict(string, t)': type = new torch.DictType(new torch.StringType(), new torch.Type()); break; + case 'Dict(string, Tensor)': type = new torch.DictType(new torch.StringType(), new torch.TensorType()); break; + case 'Dict(int64, t)': type = new torch.DictType(new torch.IntType(), new torch.Type()); break; + case 'Dict(float32, t)': type = new torch.DictType(new torch.FloatType(), new torch.Type()); break; + case 'Dict(boolean, t)': type = new torch.DictType(new torch.BoolType(), new torch.Type()); break; + case 'Dict(complex, t)': type = new torch.DictType(new torch.ComplexType(), new torch.Type()); break; + case 'Dict(Tensor, t)': type = new torch.DictType(new torch.TensorType(), new torch.Type()); break; default: { if (arg.type.startsWith('__torch__.')) { type = new torch.ClassType(arg.type); @@ -2261,6 +2298,288 @@ pytorch.jit.Execution = class extends pytorch.Execution { return super.target(expression, context); } + expression(expression, context) { + if (!this.trace) { + return super.expression(expression, context); + } + const torch = this.torch; + switch (expression.type) { + case '=': { + const target = expression.target; + if (target.type === 'id') { + let value = this.expression(expression.expression, context); + if (typeof value === 'string') { + const node = this._graph.create('prim::Constant'); + const input = new torch.Value(node); + input.value = value; + node.addInput(input); + value = node.addOutput(); + value.setType(new torch.StringType()); + } + context.set(target.value, value); + return undefined; + } else if (target.type === 'tuple') { + context.target.push(target.value); + const value = this.expression(expression.expression, context); + context.target.pop(); + if (target.value.every((item) => item.type === 'id')) { + if (value instanceof torch.Value) { + const node = this._graph.create('prim::TupleUnpack'); + node.addInput(value); + const outputs = []; + for (let i = 0; i < target.value.length; i++) { + const item = target.value[i]; + const output = node.addOutput(); + const type = value.type(); + if (type instanceof torch.ListType) { + output.setType(value.type().getElementType()); + } else if (type instanceof torch.TupleType) { + output.setType(type.elements()[i]); + }
else { + throw new pytorch.Error(`Unsupported tuple unpack type '${type.kind()}'.`); + } + output.setDebugName(item.value); + context.set(item.value, output); + outputs.push(output); + } + return outputs; + } + if (target.value.length < value.length) { + throw new python.Error(`ValueError: too many values to unpack (expected ${target.value.length}, actual ${value.length}).`); + } + if (target.value.length > value.length) { + throw new python.Error(`ValueError: not enough values to unpack (expected ${target.value.length}, actual ${value.length}).`); + } + for (let i = 0; i < value.length; i++) { + context.set(target.value[i].value, value[i]); + } + return undefined; + } + } + break; + } + case 'call': { + if (expression.target.type === 'id' && expression.target.value === 'annotate') { + let value = this.expression(expression.args[1], context); + const type = this.type(expression.args[0]); + if (value instanceof torch.Tensor) { + let name = null; + if (type instanceof torch.IntType) { + name = 'aten::IntImplicit'; + } else if (type instanceof torch.FloatType) { + name = 'aten::FloatImplicit'; + } else if (type instanceof torch.StringType) { + name = 'aten::StringImplicit'; + } else if (type instanceof torch.ComplexType) { + name = 'aten::ComplexImplicit'; + } else if (type instanceof torch.NumberType) { + name = 'aten::ScalarImplicit'; + } else { + throw new pytorch.Error(`Unsupported annotation type '${type.kind()}'.`); + } + const node = this._graph.create(name); + node.addInput(this.variable(value, node)); + value = node.addOutput(); + } + if (value instanceof torch.Value) { + value.setType(type); + } + return value; + } + if (expression.target.type === 'id' && expression.target.value === 'unchecked_cast') { + let value = this.expression(expression.args[1], context); + const type = this.type(expression.args[0]); + const node = this._graph.create('prim::unchecked_cast'); + node.addInput(this.variable(value)); + value = node.addOutput(); + value.setType(type); + return value; + } + if (expression.target.type === 'id' && expression.target.value === 'isinstance') { + let value = this.expression(expression.args[1], context); + // const type = this.type(expression.args[0]); + const node = this._graph.create('prim::isinstance'); + node.addInput(this.variable(value)); + value = node.addOutput(); + value.setType(new torch.BoolType()); + return value; + } + return super.expression(expression, context); + } + case '[]': { + if (expression.arguments.type === 'list' && expression.arguments.value.length === 1) { + const target = this.expression(expression.target, context); + if (target instanceof torch.Value && target.type() instanceof torch.ListType) { + let index = this.expression(expression.arguments.value[0], context); + const node = this._graph.create('aten::__getitem__.t'); + node.addInput(target); + if (Number.isInteger(index)) { + const value = this.invoke('torch.Value', [node]); + value.value = index; + index = value; + } + node.addInput(index); + const value = node.addOutput(); + value.setType(target.type().getElementType()); + return value; + } + if (target instanceof torch.Value && target.type() instanceof torch.DictType) { + let key = this.expression(expression.arguments.value[0], context); + const node = this._graph.create('aten::__getitem__.t'); + node.addInput(target); + if (target.type().getKeyType() instanceof torch.StringType && typeof key === 'string') { + const value = this.invoke('torch.Value', [node]); + value.value = key; + key = value; + } else { + throw new 
pytorch.Error(`Unsupported dictionary key type.`); + } + node.addInput(key); + const value = node.addOutput(); + value.setType(target.type().getValueType()); + return value; + } + if (target instanceof torch.Value && target.type() instanceof torch.TupleType) { + let index = this.expression(expression.arguments.value[0], context); + const node = this._graph.create('prim::TupleIndex'); + const value = node.addOutput(); + value.setType(target.type().elements()[index]); + node.addInput(target); + if (Number.isInteger(index)) { + const value = this.invoke('torch.Value', [node]); + value.value = index; + index = value; + } + node.addInput(index); + return value; + } + } + break; + } + case '.': { + if (expression.member.type === 'id') { + const target = this.target(expression.target, context); + if (typeof expression.member.value === 'string' && target instanceof torch.Value && target.type() instanceof torch.ClassType) { + const attribute = target.type().findAttribute(expression.member.value); + const node = this.graph.create('prim::GetAttr'); + const name = new torch.Value(node); + name.setType(new torch.StringType()); + name.value = expression.member.value; + node.addInput(target); + node.addInput(name); + const value = node.addOutput(); + value.setType(attribute); + return value; + } + return target[expression.member.value]; + } + throw new python.Error("Unsupported field expression."); + } + case 'list': { + const list = expression.value.map((item) => this.expression(item, context)); + if (/* list.length > 0 && */ list.every((item) => pytorch.Utility.isInstance(item, 'torch.Value') || Number.isInteger(item) || typeof item === 'string' || item === null)) { + const node = this._graph.create('prim::ListConstruct'); + const output = node.addOutput(); + for (const item of list) { + if (item instanceof torch.Value) { + node.addInput(item); + output.setType(new torch.ListType(item.type())); + } else if (Number.isInteger(item)) { + const value = new torch.Value(node); + value.value = item; + value.setType(new torch.IntType()); + node.addInput(value); + output.setType(new torch.ListType(new torch.IntType())); + } else if (typeof item === 'string') { + const value = new torch.Value(node); + value.value = item; + value.setType(new torch.StringType()); + node.addInput(value); + output.setType(new torch.ListType(new torch.StringType())); + } else { + const value = new torch.Value(node); + value.value = item; + node.addInput(value); + } + } + return output; + } + break; + } + case 'tuple': { + const args = expression.value.map((expression) => this.expression(expression, context)); + const node = this._graph.create('prim::TupleConstruct'); + const types = []; + const elements = []; + for (const item of args) { + if (item instanceof torch.Value) { + node.addInput(item); + types.push(item.type()); + elements.push(item); + } else if (pytorch.Utility.isTensor(item)) { + this.variable(item, node); + types.push(new torch.TensorType()); + elements.push(item); + } else if (Number.isInteger(item)) { + const value = new torch.Value(node); + value.value = item; + types.push(new torch.IntType()); + elements.push(item); + } else if (typeof item === 'boolean') { + const value = new torch.Value(node); + value.value = item; + types.push(new torch.BoolType()); + elements.push(item); + } else if (item === null) { + const value = new torch.Value(node); + value.value = item; + types.push(new torch.NoneType()); + elements.push(item); + } else { + const value = new torch.Value(node); + value.value = item; + types.push(new 
torch.Type()); + elements.push(item); + } + } + const value = node.addOutput(); + value.value = elements; + value.setType(new torch.TupleType(types)); + return value; + } + default: { + break; + } + } + return super.expression(expression, context); + } + + statement(statement, context) { + if (!this.trace) { + return super.statement(statement, context); + } + const torch = this.torch; + switch (statement.type) { + case 'class': { + super.statement(statement, context); + const value = context.get(statement.name); + const type = new torch.ClassType(`${value.__module__}.${value.__name__}`); + for (const entry of statement.body.statements) { + if (entry.type === 'var') { + const variableType = this.type(entry.variableType); + type.addAttribute(entry.name, variableType); + } + } + value.__type__ = type; + return undefined; + } + default: { + break; + } + } + return super.statement(statement, context); + } + type(expression) { const torch = this.torch; if (expression.type === '[]' && expression.target.type === 'id') { @@ -2274,8 +2593,8 @@ pytorch.jit.Execution = class extends pytorch.Execution { return new torch.OptionalType(elementType); } case 'Tuple': { - const args = expression.arguments.value.map((expression) => this.type(expression)); - return new torch.TupleType(args); + const elements = expression.arguments.value.map((expression) => this.type(expression)); + return new torch.TupleType(elements); } case 'Dict': { const key = this.type(expression.arguments.value[0]); @@ -2294,6 +2613,7 @@ pytorch.jit.Execution = class extends pytorch.Execution { case 'str': return new torch.StringType(); case 'float': return new torch.FloatType(); case 'number': return new torch.NumberType(); + case 'bool': return new torch.BoolType(); default: throw new pytorch.Error(`Unsupported type expression '${expression.value}'.`); } } @@ -2301,381 +2621,570 @@ pytorch.jit.Execution = class extends pytorch.Execution { } call(target, name, args, context) { - if (this.trace) { - const overload = this._overload(target, name, args, context); - if (overload) { - const [schema, args, evalArgs] = overload; - const copyArgs = Array.prototype.slice.call(args); - const copyEvalArgs = Array.prototype.slice.call(evalArgs); - const node = this._graph.create(schema.name); - node.schema = schema; - const referencedParameters = []; - const parameters = Array.prototype.slice.call(schema.inputs || []).concat(Array.prototype.slice.call(schema.attributes || [])); - while (copyEvalArgs.length > 0) { - if (parameters.length <= 0) { - if (schema.name.startsWith('_caffe2::')) { - break; - } + if (!this.trace) { + return super.call(target, name, args, context); + } + if (name === '__new__') { + const identifier = pytorch.Utility.target(target); + if (identifier) { + const type = this.resolve(identifier); + if (type && type.__type__) { + const node = this.graph.create('prim::CreateObject'); + const value = node.addOutput(); + value.setType(type.__type__); + return value; + } + } + } + if (name === '__init__') { + const obj = this.expression(target, context); + if (args.length === 0) { + return obj; + } + const node = this.graph.create('prim::CallMethod'); + node.addInput(obj); + const evalArgs = args.map((arg) => this.expression(arg, context)); + for (const arg of evalArgs) { + this.variable(arg, node); + } + const value = node.addOutput(); + value.setType(obj.type()); + return value; + } + const overload = this._overload(target, name, args, context); + if (!overload) { + return super.call(target, name, args, context); + } + const torch = 
this.torch; + const [schema, evalArgs] = overload; + const copyArgs = Array.prototype.slice.call(args); + const copyEvalArgs = Array.prototype.slice.call(evalArgs); + const node = this._graph.create(schema.name); + node.schema = schema; + const referencedParameters = []; + const parameters = Array.prototype.slice.call(schema.inputs || []).concat(Array.prototype.slice.call(schema.attributes || [])); + while (copyEvalArgs.length > 0) { + if (parameters.length <= 0) { + if (schema.name.startsWith('_caffe2::')) { + break; + } + if (schema.is_vararg) { + break; + } + throw new pytorch.Error(); + } + if (copyArgs.every((arg) => arg.type === '=' && arg.target && arg.target.type === 'id') && + parameters.every((parameter) => parameter.type !== 'Tensor' && parameter.type !== 'Tensor[]')) { + const map = new Map(parameters.map((parameter) => [parameter.name, parameter])); + while (copyArgs.length > 0) { + const argument = copyArgs.shift(); + const arg = copyEvalArgs.shift(); + const parameter = map.get(argument.target.value); + if (!parameter) { throw new pytorch.Error(); } - if (copyArgs.every((arg) => arg.type === '=' && arg.target && arg.target.type === 'id') && - parameters.every((parameter) => parameter.type !== 'Tensor' && parameter.type !== 'Tensor[]')) { - const map = new Map(parameters.map((parameter) => [parameter.name, parameter])); - while (copyArgs.length > 0) { - const argument = copyArgs.shift(); - const arg = copyEvalArgs.shift(); - const parameter = map.get(argument.target.value); - if (!parameter) { - throw new pytorch.Error(); - } - if (!this.isType(arg, parameter.type)) { - if (parameter.optional) { - continue; - } - throw new pytorch.Error(); - } - const value = this.variable(arg); - value.value = arg; - node.addInput(value); + if (!this.isType(arg, parameter.type)) { + if (parameter.optional) { + continue; } - continue; + throw new pytorch.Error(); } - const parameter = parameters.shift(); - const [argument] = copyEvalArgs; - if (parameter.type === 'Tensor' || (parameter.type === 'Scalar' && pytorch.Utility.isTensor(argument))) { - if (Array.isArray(argument) || (!pytorch.Utility.isTensor(argument) && argument !== null && argument !== undefined)) { - if (parameter.optional) { - continue; + const value = this.variable(arg); + value.value = arg; + node.addInput(value); + } + continue; + } + const parameter = parameters.shift(); + const [argument] = copyEvalArgs; + if (parameter.optional === true && + (parameter.type === 'float32' || parameter.type === 'boolean' || parameter.type === 'int64' || parameter.type === 'complex') && + argument instanceof torch.Value && argument.type() instanceof torch.NoneType) { + copyArgs.shift(); + copyEvalArgs.shift(); + node.addInput(argument); + } else if (parameter.type === 'Tensor[]') { + const [argument] = copyEvalArgs; + if ((argument instanceof torch.Value && this.fromType(argument.type()) === 'Tensor[]') || + (Array.isArray(argument) && argument.every((item) => pytorch.Utility.isTensor(item) || item === null || (item instanceof torch.Value && item.type() instanceof torch.TensorType)))) { + copyArgs.shift(); + copyEvalArgs.shift(); + if (argument instanceof torch.Value) { + node.addInput(argument); + } else { + const list = this._graph.create('prim::ListConstruct'); + for (const arg of argument) { + const tensor = arg; + if (tensor) { + tensor.__count__ = (tensor.__count__ || 0) + 1; } - throw new pytorch.Error(); - } else { - copyArgs.shift(); - copyEvalArgs.shift(); - const tensor = (argument === null || argument === undefined) ? 
{} : argument; const value = this.variable(tensor); - referencedParameters.push(tensor); - node.addInput(value); + value.setType(new torch.TensorType()); + list.addInput(value); } - } else if (parameter.type === 'Tensor[]') { - const [argument] = copyEvalArgs; - if (!Array.isArray(argument) || !argument.every((item) => pytorch.Utility.isTensor(item) || item === null)) { - if (parameter.optional) { - continue; - } - throw new pytorch.Error(); - } else { - copyArgs.shift(); - copyEvalArgs.shift(); - - const list = this._graph.create('prim::ListConstruct'); - for (const arg of argument) { - const tensor = arg; - if (tensor) { - tensor.__count__ = (tensor.__count__ || 0) + 1; - } - const value = this.variable(tensor); - list.addInput(value); - } - - const value = list.addOutput(); - node.addInput(value); + const value = list.addOutput(); + value.setType(new torch.ListType(new torch.TensorType())); + node.addInput(value); + } + } else { + if (parameter.optional) { + continue; + } + throw new pytorch.Error(); + } + } else { + const [arg] = copyArgs; + if (!this.isType(argument, parameter.type) && argument !== null) { + if (parameter.optional) { + continue; + } + throw new pytorch.Error('Invalid argument type.'); + } else if (arg.type === '=') { + throw new pytorch.Error('Expected named argument.'); + } else { + copyArgs.shift(); + copyEvalArgs.shift(); + if (pytorch.Utility.isInstance(argument, 'torch.Value')) { + node.addInput(argument); + } else { + const value = this.variable(argument); + if (value instanceof torch.Value) { + // value.setType(this.toType(parameter.type)); } + node.addInput(value); + value.value = argument; + } + } + } + } + const result = []; + for (let i = 0; i < schema.outputs.length; i++) { + const parameter = schema.outputs[i]; + let type = parameter.type; + if (type === 't[]') { + const index = schema.inputs.findIndex((input) => input.type === parameter.type); + if (index === -1) { + const index = schema.inputs.findIndex((input) => input.type.match(/^Dict\(\w+, t\)$/)); + if (index === -1) { + throw new pytorch.Error("Unknown value type 't[]'."); } else { - const [arg] = copyArgs; - if (!this.isType(argument, parameter.type) && argument !== null) { - if (parameter.optional) { - continue; - } - throw new pytorch.Error(); - } else if (arg.type === '=') { - throw new pytorch.Error('Expected named argument.'); + const inputs = node.inputs(); + const value = inputs[index]; + if (value instanceof torch.Value && value.type() instanceof torch.DictType) { + type = this.fromType(new torch.ListType(value.type().getValueType())); + } else if (value.value && Object.values(value.value).every((item) => pytorch.Utility.isTensor(item))) { + type = 'Tensor[]'; } else { - copyArgs.shift(); - copyEvalArgs.shift(); - const value = this.variable(argument); - node.addInput(value); - value.value = argument; + throw new pytorch.Error("Unknown dict type 't[]'."); } } + } else { + const inputs = node.inputs(); + const value = inputs[index]; + if (value instanceof torch.Value && value.type() instanceof torch.ListType) { + type = this.fromType(value.type()); + } else if (Array.isArray(value) && value.length > 0 && value.every((item) => Number.isInteger(item))) { + type = 'int64[]'; + } else if (value.value && Array.isArray(value.value) && value.value.length > 0 && value.value.every((item) => Number.isInteger(item) || isNaN(item))) { + type = 'int64[]'; + } else if (value.value && Array.isArray(value.value) && value.value.length > 0 && value.value.every((item) => pytorch.Utility.isTensor(item))) { + 
type = 'Tensor[]'; + } else { + throw new pytorch.Error("Unknown value type 't[]'."); + } } - const result = []; - for (let i = 0; i < schema.outputs.length; i++) { - const parameter = schema.outputs[i]; - switch (parameter.type) { - case 'Scalar': - case 'Tensor': { - const output = this.invoke('torch.Tensor', []); - output.__origin__ = schema.name; - if (i === 0) { - switch (schema.name) { - case 'aten::conv1d': - case 'aten::embedding': { - output.resize_([NaN, NaN, NaN]); - break; - } - case 'aten::cat': - case 'aten::conv2d': - case 'aten::dropout': - case 'aten::flatten': - case 'aten::flatten.named_out_dim': - case 'aten::max_pool2d': - case 'aten::adaptive_avg_pool2d': - case 'aten::avg_pool2d': - case 'aten::quantize_per_tensor': - case 'aten::relu_': - case 'aten::prelu': - case 'aten::hardtanh_': - case 'aten::upsample_bilinear2d': - case 'prepacked::conv2d_clamp_run': { - const [input] = evalArgs; - if (pytorch.Utility.isTensor(input) && input.size() === undefined) { - input.resize_([NaN, NaN, NaN, NaN]); - } - output.resize_([NaN, NaN, NaN, NaN]); - break; - } - case 'aten::slice': - case 'aten::slice.Tensor': { - const [input] = evalArgs; - if (pytorch.Utility.isTensor(input) && Array.isArray(input.size())) { - const size = input.size(); - output.resize_(size); - } - break; - } - case 'aten::to': - case 'aten::to.device': - case 'aten::to.dtype': - case 'aten::to.dtype_layout': { - const [input] = evalArgs; - if (pytorch.Utility.isTensor(input) && Array.isArray(input.size())) { - const size = input.size(); - output.resize_(size); - } - break; - } - case 'aten::conv3d': { - output.resize_([NaN, NaN, NaN, NaN, NaN]); - break; - } - case 'aten::roll': - case 'aten::detach': - case 'aten::mean': - case 'aten::mul': - case 'aten::mul.Scalar': - case 'aten::div': - case 'aten::div.Scalar': - case 'aten::batch_norm': - case 'aten::gelu': - case 'aten::relu': - case 'aten::clamp': - case 'aten::clamp_': - case 'aten::_add_relu_': - case 'aten::hardswish_': { - const [input] = evalArgs; - if (pytorch.Utility.isTensor(input) && Array.isArray(input.size())) { - output.resize_(input.size()); - } - break; - } - case 'aten::add': - case 'aten::add.Scalar': - case 'aten::sub': - case 'aten::sub.Scalar': { - const [input] = evalArgs; - if (pytorch.Utility.isTensor(input) && Array.isArray(input.size())) { - output.resize_(input.size()); - } else { - const [, other] = evalArgs; - if (pytorch.Utility.isTensor(other) && Array.isArray(other.size())) { - output.resize_(other.size()); - } - } - break; - } - case 'aten::select': - case 'aten::select.int': { - const [input] = evalArgs; - if (pytorch.Utility.isTensor(input) && Array.isArray(input.size())) { - output.resize_(Array(input.size().length - 1).fill(NaN)); - } - break; - } - case 'aten::layer_norm': { - const [input, normalized_shape] = evalArgs; - if (pytorch.Utility.isTensor(input) && Array.isArray(input.size())) { - const shape = input.size(); - if (Array.isArray(normalized_shape) && normalized_shape.length === 1) { - const [value] = normalized_shape; - shape[shape.length - 1] = value; - } - output.resize_(shape); - } - break; - } - case 'aten::empty': - case 'aten::ones': - case 'aten::zeros': - case 'aten::zeros_like': { - output.resize_(evalArgs[0]); - break; + } else if (type === 't') { + const index = schema.inputs.findIndex((input) => input.type === parameter.type); + const value = node.inputs()[index]; + if (value instanceof torch.Value && value.type()) { + type = value.type(); + if (type instanceof torch.ListType && 
type.getElementType() instanceof torch.IntType) { + type = 'int64[]'; + } + } else if (Array.isArray(value.value) && value.value.every((item) => Number.isInteger(item))) { + type = 'int64[]'; + } else if (Number.isInteger(value.value)) { + type = 'int64'; + } else if (pytorch.Utility.isTensor(value.value)) { + type = 'Tensor'; + } else { + throw new pytorch.Error("Unknown value type 't'."); + } + } + switch (type) { + case 'Tensor': { + const output = this.invoke('torch.Tensor', []); + output.__origin__ = schema.name; + if (i === 0) { + switch (schema.name) { + case 'aten::conv1d': + case 'aten::embedding': { + output.resize_([NaN, NaN, NaN]); + break; + } + case 'aten::cat': + case 'aten::conv2d': + case 'aten::dropout': + case 'aten::flatten': + case 'aten::flatten.named_out_dim': + case 'aten::max_pool2d': + case 'aten::adaptive_avg_pool2d': + case 'aten::avg_pool2d': + case 'aten::quantize_per_tensor': + case 'aten::relu_': + case 'aten::prelu': + case 'aten::hardtanh_': + case 'aten::upsample_bilinear2d': + case 'prepacked::conv2d_clamp_run': { + const [input] = evalArgs; + if (pytorch.Utility.isTensor(input) && input.size() === undefined) { + input.resize_([NaN, NaN, NaN, NaN]); + } + output.resize_([NaN, NaN, NaN, NaN]); + break; + } + case 'aten::slice': + case 'aten::slice.Tensor': { + const [input] = evalArgs; + if (pytorch.Utility.isTensor(input) && Array.isArray(input.size())) { + const size = input.size(); + output.resize_(size); + } + break; + } + case 'aten::to': + case 'aten::to.device': + case 'aten::to.dtype': + case 'aten::to.dtype_layout': { + const [input] = evalArgs; + if (pytorch.Utility.isTensor(input) && Array.isArray(input.size())) { + const size = input.size(); + output.resize_(size); + } + break; + } + case 'aten::conv3d': { + output.resize_([NaN, NaN, NaN, NaN, NaN]); + break; + } + case 'aten::roll': + case 'aten::detach': + case 'aten::mean': + case 'aten::mul': + case 'aten::mul.Scalar': + case 'aten::div': + case 'aten::div.Scalar': + case 'aten::batch_norm': + case 'aten::gelu': + case 'aten::relu': + case 'aten::clamp': + case 'aten::clamp_': + case 'aten::_add_relu_': + case 'aten::hardswish_': { + const [input] = evalArgs; + if (pytorch.Utility.isTensor(input) && Array.isArray(input.size())) { + output.resize_(input.size()); + } + break; + } + case 'aten::add': + case 'aten::add.Scalar': + case 'aten::sub': + case 'aten::sub.Scalar': { + const [input] = evalArgs; + if (pytorch.Utility.isTensor(input) && Array.isArray(input.size())) { + output.resize_(input.size()); + } else { + const [, other] = evalArgs; + if (pytorch.Utility.isTensor(other) && Array.isArray(other.size())) { + output.resize_(other.size()); } - case 'aten::view': - case 'aten::reshape': - case 'aten::new_full': { - output.resize_(evalArgs[1]); - break; + } + break; + } + case 'aten::select': + case 'aten::select.int': { + const [input] = evalArgs; + if (pytorch.Utility.isTensor(input) && Array.isArray(input.size())) { + output.resize_(Array(input.size().length - 1).fill(NaN)); + } + break; + } + case 'aten::layer_norm': { + const [input, normalized_shape] = evalArgs; + if (pytorch.Utility.isTensor(input) && Array.isArray(input.size())) { + const shape = input.size(); + if (Array.isArray(normalized_shape) && normalized_shape.length === 1) { + const [value] = normalized_shape; + shape[shape.length - 1] = value; } - case 'aten::squeeze': - case 'aten::squeeze.dim': { - const [input] = evalArgs; - const size = input.size(); - if (Array.isArray(size)) { - switch (evalArgs.length) { - case 1: 
{ - output.resize_(size.filter((value) => value !== 1)); - break; - } - case 2: { - const [, dim] = evalArgs; - output.resize_(size.filter((value, index) => (value !== 1 && !isNaN(value)) || index !== dim)); - break; - } - default: { - break; - } - } + output.resize_(shape); + } + break; + } + case 'aten::empty': + case 'aten::ones': + case 'aten::zeros': + case 'aten::zeros_like': { + output.resize_(evalArgs[0]); + break; + } + case 'aten::view': + case 'aten::reshape': + case 'aten::new_full': { + output.resize_(evalArgs[1]); + break; + } + case 'aten::squeeze': + case 'aten::squeeze.dim': { + const [input] = evalArgs; + const size = input.size(); + if (Array.isArray(size)) { + switch (evalArgs.length) { + case 1: { + output.resize_(size.filter((value) => value !== 1)); + break; } - break; - } - case 'aten::unsqueeze': { - const [input, dim] = evalArgs; - const size = input.size(); - if (Array.isArray(size) && dim !== undefined) { - const shape = size.slice(); - shape.splice(dim, 0, 1); - output.resize_(shape); - } else { - output.resize_([NaN, NaN, NaN, NaN]); + case 2: { + const [, dim] = evalArgs; + output.resize_(size.filter((value, index) => (value !== 1 && !isNaN(value)) || index !== dim)); + break; } - break; - } - case 'aten::transpose': - case 'aten::transpose.int': { - const [input, dim0, dim1] = evalArgs; - if (pytorch.Utility.isTensor(input) && Array.isArray(input.size())) { - const size = input.size().slice(); - const d0 = dim0 >= 0 ? dim0 : size.length + dim0; - const d1 = dim1 >= 0 ? dim1 : size.length + dim1; - const value = size[dim0]; - /* eslint-disable prefer-destructuring */ - size[d0] = size[1]; - /* eslint-enable prefer-destructuring */ - size[d1] = value; - output.resize_(size); + default: { + break; } - break; - } - case 'aten::contiguous': { - const [source] = evalArgs; - output.__source__ = source; - break; } - case 'quantized::cat': - case 'quantized::cat_relu': - case 'quantized::linear': - case 'quantized::conv2d': - case 'quantized::conv2d.new': - case 'quantized::conv2d_relu': - case 'quantized::conv2d_relu.new': - case 'quantized::add': - case 'quantized::add_relu': - output.resize_([NaN, NaN, NaN, NaN]); - output.__quantized__ = true; - break; - default: - break; } + break; } - this.variable(output, node); - result.push(output); - break; - } - case 'Tensor[]': { - let count = 1; - switch (schema.name) { - case 'aten::chunk': - count = node.inputs()[1].value; - break; - case 'aten::meshgrid': { - const list = node.inputs()[0].node(); - if (list.kind() === 'prim::ListConstruct') { - count = list.inputs().length; + case 'aten::unsqueeze': { + const [input, dim] = evalArgs; + if (pytorch.Utility.isTensor(input)) { + const size = input.size(); + if (Array.isArray(size) && dim !== undefined) { + const shape = size.slice(); + shape.splice(dim, 0, 1); + output.resize_(shape); + } else { + output.resize_([NaN, NaN, NaN, NaN]); } - break; } - case 'aten::unbind': - case 'aten::unbind.int': - count = args[0].__tuple__ || count; - break; - case 'aten::broadcast_tensors': - case 'aten::split': - case 'aten::split.Tensor': - case 'aten::split_with_sizes': - if (context.target.length > 0) { - count = context.target[context.target.length - 1].length; - } - break; - default: - break; + break; } - - const value = node.addOutput(); - const list = this._graph.create('prim::ListUnpack'); - list.addInput(value); - - const tensors = []; - for (let i = 0; i < count; i ++) { - const tensor = this.invoke('torch.Tensor', []); - tensor.__origin__ = schema.name; - 
this.variable(tensor, list); - tensors.push(tensor); + case 'aten::transpose': + case 'aten::transpose.int': { + const [input, dim0, dim1] = evalArgs; + if (pytorch.Utility.isTensor(input) && Array.isArray(input.size())) { + const size = input.size().slice(); + const d0 = dim0 >= 0 ? dim0 : size.length + dim0; + const d1 = dim1 >= 0 ? dim1 : size.length + dim1; + const value = size[d0]; + /* eslint-disable prefer-destructuring */ + size[d0] = size[d1]; + /* eslint-enable prefer-destructuring */ + size[d1] = value; + output.resize_(size); + } + break; } - result.push(tensors); - break; + case 'aten::contiguous': { + const [source] = evalArgs; + output.__source__ = source; + break; + } + case 'quantized::cat': + case 'quantized::cat_relu': + case 'quantized::linear': + case 'quantized::conv2d': + case 'quantized::conv2d.new': + case 'quantized::conv2d_relu': + case 'quantized::conv2d_relu.new': + case 'quantized::add': + case 'quantized::add_relu': + output.resize_([NaN, NaN, NaN, NaN]); + output.__quantized__ = true; + break; + default: + break; } - case '__torch__.torch.classes.quantized.Conv2dPackedParamsBase': - case '__torch__.torch.classes.quantized.Conv3dPackedParamsBase': - case '__torch__.torch.classes.quantized.LinearPackedParamsBase': - case '__torch__.torch.classes.rnn.CellParamsBase': - case '__torch__.torch.classes.xnnpack.Conv2dOpContext': - case '__torch__.torch.classes.xnnpack.LinearOpContext': - case '__torch__.torch.classes.xnnpack.TransposeConv2dOpContext': { - const value = this.invoke(parameter.type, []); - this.variable(value, node); - result.push(value); + } + this.variable(output, node); + result.push(output); + break; + } + case 'Tensor[]': { + let count = 1; + switch (schema.name) { + case 'aten::chunk': + count = node.inputs()[1].value; break; - } - default: { - const output = this.invoke('torch.Tensor', []); - output.resize_([]); - output.__origin__ = schema.name; - this.variable(output, node); - result.push(output); + case 'aten::meshgrid': { + const list = node.inputs()[0].node(); + if (list.kind() === 'prim::ListConstruct') { + count = list.inputs().length; + } break; } + case 'aten::unbind': + case 'aten::unbind.int': + count = args[0].__tuple__ || count; + break; + case 'aten::broadcast_tensors': + case 'aten::split': + case 'aten::split.Tensor': + case 'aten::split_with_sizes': + if (context.target.length > 0) { + count = context.target[context.target.length - 1].length; + } + break; + default: + break; + } + + const value = node.addOutput(); + value.setType(new torch.ListType(new torch.TensorType())); + result.push(value); + + /* + const value = node.addOutput(); + const list = this._graph.create('prim::ListUnpack'); + list.addInput(value); + + const tensors = []; + for (let i = 0; i < count; i ++) { + const tensor = this.invoke('torch.Tensor', []); + tensor.__origin__ = schema.name; + this.variable(tensor, list); + tensors.push(tensor); } + result.push(tensors); + */ + break; + } + case '__torch__.torch.classes.quantized.Conv2dPackedParamsBase': + case '__torch__.torch.classes.quantized.Conv3dPackedParamsBase': + case '__torch__.torch.classes.quantized.LinearPackedParamsBase': + case '__torch__.torch.classes.rnn.CellParamsBase': + case '__torch__.torch.classes.xnnpack.Conv2dOpContext': + case '__torch__.torch.classes.xnnpack.LinearOpContext': + case '__torch__.torch.classes.xnnpack.TransposeConv2dOpContext': { + const value = this.invoke(parameter.type, []); + this.variable(value, node); + result.push(value); + break; }
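+ // Non-tensor outputs below are materialized as typed torch.Value placeholders instead of empty tensors.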
- for (const referencedParameter of referencedParameters) { - referencedParameter.__count__ = (referencedParameter.__count__ || 0) + 1; + case 'Scalar': { + const value = this.variable(null, node); + value.__origin__ = schema.name; + value.setType(new torch.NumberType()); + result.push(value); + break; + } + case 'boolean': { + const value = this.variable(null, node); + value.__origin__ = schema.name; + value.setType(new torch.BoolType()); + result.push(value); + break; } - if (result.length > 1) { - return result; + case 'boolean[]': { + const value = this.variable(null, node); + value.__origin__ = schema.name; + value.setType(new torch.ListType(new torch.BoolType())); + result.push(value); + break; + } + case 'string[]': { + const value = this.variable(null, node); + value.__origin__ = schema.name; + value.setType(new torch.ListType(new torch.StringType())); + result.push(value); + break; + } + case 'int64': { + const value = this.variable(null, node); + value.__origin__ = schema.name; + value.setType(new torch.IntType()); + switch (schema.name) { + case 'aten::div.int': value.value = torch.div(evalArgs[0], evalArgs[1]); break; + case 'aten::dim': value.value = torch.dim(evalArgs[0]); break; + case 'aten::len.t': value.value = torch.len(evalArgs[0]); break; + // case 'aten::size.int': value.value = torch.size(evalArgs[0], evalArgs[1]); break; + default: break; + } + result.push(value); + break; + } + case 'int64[]': { + const value = this.variable(null, node); + value.__origin__ = schema.name; + value.setType(new torch.ListType(new torch.IntType())); + switch (schema.name) { + // case 'aten::size': value.value = torch.size(evalArgs[0], evalArgs[1]); break; + default: break; + } + result.push(value); + break; + } + case 'float32': { + const value = this.variable(null, node); + value.__origin__ = schema.name; + value.setType(new torch.FloatType()); + result.push(value); + break; + } + case 'float32[]': { + const value = this.variable(null, node); + value.__origin__ = schema.name; + value.setType(new torch.ListType(new torch.FloatType())); + result.push(value); + break; + } + case 'complex': { + const value = this.variable(null, node); + value.__origin__ = schema.name; + value.setType(new torch.ComplexType()); + result.push(value); + break; + } + case 'string': { + const value = this.variable(null, node); + value.__origin__ = schema.name; + value.setType(new torch.StringType()); + result.push(value); + break; + } + case 'Dict(string, Tensor)': { + const value = this.variable(null, node); + value.__origin__ = schema.name; + value.setType(new torch.DictType(new torch.StringType(), new torch.TensorType())); + result.push(value); + break; + } + case 'Dict(string, t)': { + const value = this.variable(null, node); + value.__origin__ = schema.name; + value.setType(new torch.DictType(new torch.StringType(), new torch.TensorType())); // TODO: extract the actual value type instead of assuming Tensor + result.push(value); + break; + } + default: { + const output = this.invoke('torch.Tensor', []); + output.resize_([]); + output.__origin__ = schema.name; + this.variable(output, node); + result.push(output); + break; } - return result[0]; } } - return super.call(target, name, args, context); + for (const referencedParameter of referencedParameters) { + referencedParameter.__count__ = (referencedParameter.__count__ || 0) + 1; + } + if (result.length > 1) { + return result; + } + return result[0]; } isType(obj, type) { @@ -2685,7 +3194,8 @@ pytorch.jit.Execution = class extends pytorch.Execution {
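+ // isType also accepts traced torch.Value operands whose static type matches the schema type.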
return !Array.isArray(obj) && (pytorch.Utility.isTensor(obj) || obj === null || (obj instanceof torch.Value && obj.type() instanceof torch.TensorType)); case 'Tensor[]': - return Array.isArray(obj) && obj.length > 0 && obj.every((tensor) => pytorch.Utility.isTensor(tensor) || tensor === null); + return Array.isArray(obj) && obj.length > 0 && + obj.every((tensor) => pytorch.Utility.isTensor(tensor) || tensor === null || (tensor instanceof torch.Value && tensor.type() instanceof torch.TensorType)); case 'Scalar': return (obj !== null && (obj !== Object(obj) || obj instanceof Number)) || (pytorch.Utility.isTensor(obj) && Array.isArray(obj.size()) && obj.size().length === 0) || @@ -2696,12 +3206,12 @@ pytorch.jit.Execution = class extends pytorch.Execution { if (Array.isArray(obj) && obj.every((item) => item === true || item === false)) { return true; } - if (pytorch.Utility.isInstance(obj, 'torch.Value') && pytorch.Utility.isInstance(obj.type(), 'torch.ListType') && pytorch.Utility.isInstance(obj.type().getElementType(), 'torch.BoolType')) { + if (obj instanceof torch.Value && obj.type() instanceof torch.ListType && obj.type().getElementType() instanceof torch.BoolType) { return true; } return false; case 'string': - return obj === null || typeof obj === 'string'; + return obj === null || typeof obj === 'string' || (obj instanceof torch.Value && obj.type() instanceof torch.StringType); case 'SymInt': case 'int64': return Number.isInteger(obj) || typeof obj === 'bigint' || @@ -2755,16 +3265,35 @@ pytorch.jit.Execution = class extends pytorch.Execution { case 'Device': return obj === null || obj === Object(obj); case 't[]': - return Array.isArray(obj) || (pytorch.Utility.isInstance(obj, 'torch.Value') && pytorch.Utility.isInstance(obj.type(), 'torch.ListType')); + return Array.isArray(obj) || (obj instanceof torch.Value && obj.type() instanceof torch.ListType); case 't': return true; case 'AnyEnumType': return false; + case 'complex': + return obj instanceof torch.Value && obj.type() instanceof torch.ComplexType; + case 'Any[]': + if (Array.isArray(obj)) { + return true; + } + if (obj instanceof torch.Value && obj.type() instanceof torch.ListType) { + return true; + } + return false; default: if (type && type.startsWith('__torch__.') && obj && obj.__class__ && obj.__class__.__module__ && obj.__class__.__name__) { return type === `${obj.__class__.__module__}.${obj.__class__.__name__}`; } + if (type.startsWith('Dict(') && type.endsWith(')') && obj instanceof torch.Value && obj.type() instanceof torch.DictType) { + const params = type.substring(5, type.length - 1).split(',').map((item) => item.trim()); + if ((params[0] === 't' || params[0] === this.fromType(obj.type().getKeyType())) && + (params[1] === 't' || params[1] === this.fromType(obj.type().getValueType()))) { + return true; + } + return false; + } + // throw new pytorch.Error(`Unknown type '${type}'.`); return true; } } @@ -2783,191 +3312,253 @@ pytorch.jit.Execution = class extends pytorch.Execution { if (type instanceof torch.FloatType) { return 'float32'; } + if (type instanceof torch.BoolType) { + return 'boolean'; + } + if (type instanceof torch.StringType) { + return 'string'; + } throw new pytorch.Error(`Unknown type '${type.kind()}'.`); } _overload(target, name, args, context) { - let moduleName = pytorch.Utility.target(target); - if (moduleName) { - let outputTypes = null; - let type = name ? 
`${moduleName}.${name}` : moduleName; - if (type === 'ops.prim.NumToTensor' && args.length === 1 && args[0].type === 'call' && args[0].target.member.type === 'id') { - const [arg] = args; - moduleName = pytorch.Utility.target(arg.target.target); - name = arg.target.member.value; - args = arg.args; - outputTypes = ['int64']; - type = `${moduleName}.${name}`; - } - // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml - let overloads = null; - if (type.startsWith('torch.')) { - overloads = this._types.get(`aten::${type.substring(6)}`); - /* } else if (type.startsWith('ops.prim.')) { - overloads = this._types.get(`prim::${type.substring(9)}`); - } else if (type === 'int') { - overloads = this._types.get(`aten::Int`); - // "bool": "aten::Bool" - // "int": "aten::Int" - // "float": "aten::Float" - // "complex": "aten::Complex" - // "abs": "prim::abs" - // "max": "prim::max" - // "min": "prim::min" - // "range": "fake::does_not_exist" - */ - } else if (type.startsWith('ops.') && !type.startsWith('ops.prim.')) { - const path = type.split('.'); - if (path.length === 3) { - overloads = this._types.get(`${path[1]}::${path[2]}`); - } - if (!overloads) { - const module = this.import(moduleName); - if (!module || !module[name]) { - const metadata = {}; - metadata.name = type; - metadata.inputs = []; - metadata.outputs = []; - for (let i = 0; i < args.length; i++) { - const input = {}; - let argument = args[i]; - input.name = i.toString(); - if (argument.type === '=' && argument.target && argument.target.type === 'id') { - input.name = this.expression(argument.target, context); - argument = argument.expression; - } - const obj = this.expression(argument, context); - input.type = pytorch.Utility.getType(obj); - metadata.inputs.push(input); - } - const count = context.target.length > 0 ? context.target[context.target.length - 1].length : 0; - for (let i = 0; i < count; i++) { - metadata.outputs.push({ name: '', type: '' }); + const moduleName = pytorch.Utility.target(target); + if (!moduleName) { + return null; + } + const torch = this.torch; + const type = name ? 
`${moduleName}.${name}` : moduleName; + // const outputTypes = null; + // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml + let overloads = null; + if (type.startsWith('torch.')) { + overloads = this._types.get(`aten::${type.substring(6)}`); + } else if (type.startsWith('ops.prim.')) { + overloads = this._types.get(`prim::${type.substring(9)}`); + } else if (type === 'int') { + overloads = this._types.get(`aten::Int`); + } else if (type === 'str') { + overloads = this._types.get(`aten::str`); + // "bool": "aten::Bool" + // "int": "aten::Int" + // "float": "aten::Float" + // "complex": "aten::Complex" + // "range": "fake::does_not_exist" + } else if (type.startsWith('ops.') && !type.startsWith('ops.prim.')) { + const path = type.split('.'); + if (path.length === 3) { + overloads = this._types.get(`${path[1]}::${path[2]}`); + } + if (!overloads) { + const module = this.import(moduleName); + if (!module || !module[name]) { + const metadata = {}; + metadata.name = type; + metadata.inputs = []; + metadata.outputs = []; + for (let i = 0; i < args.length; i++) { + const input = {}; + let argument = args[i]; + input.name = i.toString(); + if (argument.type === '=' && argument.target && argument.target.type === 'id') { + input.name = this.expression(argument.target, context); + argument = argument.expression; } - this._metadata.add(type, metadata); - overloads = [metadata]; + const obj = this.expression(argument, context); + input.type = pytorch.Utility.getType(obj); + metadata.inputs.push(input); } + const count = context.target.length > 0 ? context.target[context.target.length - 1].length : 0; + for (let i = 0; i < count; i++) { + metadata.outputs.push({ name: '', type: '' }); + } + this._metadata.add(type, metadata); + overloads = [metadata]; } } - if (overloads) { - overloads = Array.isArray(overloads) ? overloads : [overloads]; - const evalArgs = args.map((argument) => { - if (argument.type === '=' && argument.target && argument.target.type === 'id') { - argument = argument.expression; - } - return this.expression(argument, context); - }); - for (const schema of overloads) { - const copyArgs = Array.prototype.slice.call(args); - const copyEvalArgs = Array.prototype.slice.call(evalArgs); - const parameters = Array.prototype.slice.call(schema.inputs || []).concat(Array.prototype.slice.call(schema.attributes || [])); - let next = false; - while (copyEvalArgs.length > 0) { - if (parameters.length <= 0) { - next = !schema.name.startsWith('_caffe2::'); + } + if (!overloads) { + if (type.startsWith('aten::') || type.startsWith('prim::')) { + throw new pytorch.Error(`Unknown function '${type}'.`); + } + return null; + } + overloads = Array.isArray(overloads) ? 
overloads : [overloads]; + const evalArgs = args.map((argument) => { + if (argument.type === '=' && argument.target && argument.target.type === 'id') { + argument = argument.expression; + } + return this.expression(argument, context); + }); + const matches = []; + for (const schema of overloads) { + const copyArgs = Array.prototype.slice.call(args); + const copyEvalArgs = Array.prototype.slice.call(evalArgs); + const parameters = Array.prototype.slice.call(schema.inputs || []).concat(Array.prototype.slice.call(schema.attributes || [])); + let next = false; + let kwarg_only = false; + while (copyEvalArgs.length > 0) { + if (parameters.length <= 0) { + next = !schema.name.startsWith('_caffe2::') && !schema.is_vararg; + break; + } + if (copyArgs.every((arg) => arg.type === '=' && arg.target && arg.target.type === 'id') && + parameters.every((parameter) => parameter.type !== 'Tensor' && parameter.type !== 'Tensor[]')) { + const map = new Map(parameters.map((parameter) => [parameter.name, parameter])); + while (copyArgs.length > 0) { + const argument = copyArgs.shift(); + const arg = copyEvalArgs.shift(); + const parameter = map.get(argument.target.value); + if (!parameter) { + next = true; break; } - if (copyArgs.every((arg) => arg.type === '=' && arg.target && arg.target.type === 'id') && - parameters.every((parameter) => parameter.type !== 'Tensor' && parameter.type !== 'Tensor[]')) { - const map = new Map(parameters.map((parameter) => [parameter.name, parameter])); - while (copyArgs.length > 0) { - const argument = copyArgs.shift(); - const arg = copyEvalArgs.shift(); - const parameter = map.get(argument.target.value); - if (!parameter) { - next = true; - break; - } - if (!this.isType(arg, parameter.type)) { - if (parameter.optional) { - continue; - } - next = true; - break; - } - } - continue; + if (parameter.kwarg_only) { + kwarg_only = true; } - if (next) { + if (!this.isType(arg, parameter.type)) { + if (parameter.optional) { + continue; + } + next = true; break; } - const parameter = parameters.shift(); - const [argument] = copyEvalArgs; - if (parameter.type === 'Tensor' || (parameter.type === 'Scalar' && pytorch.Utility.isTensor(argument))) { - if (Array.isArray(argument) || (!pytorch.Utility.isTensor(argument) && argument !== null && argument !== undefined)) { - if (parameter.optional) { - continue; - } - next = true; - } else { - copyArgs.shift(); - copyEvalArgs.shift(); - } - } else if (parameter.type === 'Tensor[]') { - const [argument] = copyEvalArgs; - if (!Array.isArray(argument) || !argument.every((item) => pytorch.Utility.isTensor(item) || item === null)) { - if (parameter.optional) { - continue; - } - next = true; - } else { - copyArgs.shift(); - copyEvalArgs.shift(); - } - } else { - const [arg] = copyArgs; - if (!this.isType(argument, parameter.type) && argument !== null) { - if (parameter.optional) { - continue; - } - next = true; - } else if (arg.type === '=') { - throw new pytorch.Error('Expected named argument.'); - } else { - copyArgs.shift(); - copyEvalArgs.shift(); - } + } + continue; + } + if (next) { + break; + } + const parameter = parameters.shift(); + if (parameter.kwarg_only) { + kwarg_only = true; + } + const [argument] = copyEvalArgs; + /* if (parameter.type === 'Tensor' || (parameter.type === 'Scalar' && pytorch.Utility.isTensor(argument))) { + if (Array.isArray(argument) || (!pytorch.Utility.isTensor(argument) && argument !== null && argument !== undefined)) { + if (parameter.optional) { + continue; } - if (next) { - break; + next = true; + } else { + 
copyArgs.shift(); + copyEvalArgs.shift(); + } + } else */ + if (parameter.optional === true && + (parameter.type === 'float32' || parameter.type === 'boolean' || parameter.type === 'int64' || parameter.type === 'complex' || parameter.type === 'ScalarType' || parameter.type === 'Device' || parameter.type === 'Layout') && + argument instanceof torch.Value && argument.type() instanceof torch.NoneType) { + copyArgs.shift(); + copyEvalArgs.shift(); + } else if (parameter.type === 'Tensor[]') { + const [argument] = copyEvalArgs; + if ((argument instanceof torch.Value && this.fromType(argument.type()) === 'Tensor[]') || + (Array.isArray(argument) && argument.every((item) => pytorch.Utility.isTensor(item) || item === null || (item instanceof torch.Value && item.type() instanceof torch.TensorType)))) { + copyArgs.shift(); + copyEvalArgs.shift(); + } else { + if (parameter.optional) { + continue; } + next = true; } - if (next) { - continue; + } else if (parameter.type === 't[]') { + if (!Array.isArray(argument) && (argument instanceof torch.Value === false || argument.type() instanceof torch.ListType === false)) { + if (parameter.optional) { + continue; + } + next = true; + } else { + copyArgs.shift(); + copyEvalArgs.shift(); } - for (let i = 0; i < schema.outputs.length; i++) { - const parameter = schema.outputs[i]; - switch (parameter.type) { - case 'Scalar': - case 'Tensor': - case 'Tensor[]': - break; - // case 'int64': - // break; - case '__torch__.torch.classes.xnnpack.LinearOpContext': - case '__torch__.torch.classes.xnnpack.Conv2dOpContext': - case '__torch__.torch.classes.xnnpack.TransposeConv2dOpContext': - case '__torch__.torch.classes.rnn.CellParamsBase': - case '__torch__.torch.classes.quantized.LinearPackedParamsBase': - case '__torch__.torch.classes.quantized.Conv2dPackedParamsBase': - case '__torch__.torch.classes.quantized.Conv3dPackedParamsBase': - break; - default: { - if (!outputTypes || schema.outputs.length !== 1 || schema.outputs[0].type !== outputTypes[0]) { - next = true; - } - break; - } + } else { + const [arg] = copyArgs; + if (!this.isType(argument, parameter.type) && argument !== null) { + if (parameter.optional) { + continue; } + next = true; + } else if (arg.type === '=') { + next = true; + // throw new pytorch.Error('Expected named argument.'); + } else { + copyArgs.shift(); + copyEvalArgs.shift(); } - if (next) { - continue; + } + if (next) { + break; + } + } + if (next) { + continue; + } + if (!kwarg_only && parameters.some((parameter) => parameter.default === undefined)) { + continue; + } + for (let i = 0; i < schema.outputs.length; i++) { + const parameter = schema.outputs[i]; + switch (parameter.type) { + case 'Scalar': + case 'Tensor': + case 'Tensor[]': + case 'float32': + case 'float32[]': + case 'int64': + case 'int64[]': + case 'Device': + case 'boolean': + case 'boolean[]': + case 't': + case 't[]': + case 'complex': + case 'string': + case 'string[]': + case 'Dict(string, Tensor)': + case 'Dict(string, t)': + case 'Any': + break; + case '__torch__.torch.classes.xnnpack.LinearOpContext': + case '__torch__.torch.classes.xnnpack.Conv2dOpContext': + case '__torch__.torch.classes.xnnpack.TransposeConv2dOpContext': + case '__torch__.torch.classes.rnn.CellParamsBase': + case '__torch__.torch.classes.quantized.LinearPackedParamsBase': + case '__torch__.torch.classes.quantized.Conv2dPackedParamsBase': + case '__torch__.torch.classes.quantized.Conv3dPackedParamsBase': + break; + default: { + throw new pytorch.Error(`Unknown return type '${parameter.type}'.`); + // 
if (!outputTypes || schema.outputs.length !== 1 || schema.outputs[0].type !== outputTypes[0]) { + // next = true; + // } + // break; } - return [schema, args, evalArgs]; } } + if (next) { + continue; + } + matches.push(schema); } - return null; + if (matches.length > 1) { + const keys = new Map([['int64', 1], ['float32', 2], ['Scalar', 3]]); + matches.sort((a, b) => { + let keyA = keys.get(a.inputs[0].type) || 4; + let keyB = keys.get(b.inputs[0].type) || 4; + if (keyA === keyB && a.inputs.length > 1 && b.inputs.length > 1) { + keyA = keys.get(a.inputs[1].type) || 4; + keyB = keys.get(b.inputs[1].type) || 4; + } + return keyA - keyB; + }); + } + if (matches.length > 0) { + return [matches[0], evalArgs]; + } + throw new pytorch.Error(`Unknown function '${type}'.`); + // console.log(` ${type}`); + // return null; } block(statements, context) { @@ -3523,7 +4114,9 @@ pytorch.jit.FlatBuffersLoader = class { for (const [name, value] of this._all_functions) { const class_index = module.ivalues[name].val.class_type; const class_type = this._all_types[class_index]; - class_type.addMethod(value); + if (value) { + class_type.addMethod(value); + } } m._min_operator_version = module.operator_version; m._bytecode_version = module.bytecode_version; @@ -3734,6 +4327,28 @@ pytorch.Utility = class { } } + static toType(type) { + if (pytorch.Utility.isInstance(type, 'torch.ListType')) { + return `${pytorch.Utility.toType(type.getElementType())}[]`; + } + if (pytorch.Utility.isInstance(type, 'torch.IntType')) { + return `int64`; + } + if (pytorch.Utility.isInstance(type, 'torch.FloatType')) { + return `float32`; + } + if (pytorch.Utility.isInstance(type, 'torch.StringType')) { + return `string`; + } + if (pytorch.Utility.isInstance(type, 'torch.ComplexType')) { + return `complex`; + } + if (pytorch.Utility.isInstance(type, 'torch.TensorType')) { + return `Tensor`; + } + throw new pytorch.Error(`Unsupported type '${type.kind()}'.`); + } + static isObjectType(type) { switch (type) { case '__torch__.torch.classes.xnnpack.LinearOpContext': diff --git a/test/models.json b/test/models.json index 8563a8aead..35c3be23a1 100644 --- a/test/models.json +++ b/test/models.json @@ -5243,7 +5243,6 @@ "target": "deeplabv3_scripted.ptl", "source": "https://github.com/lutzroeder/netron/files/9562007/deeplabv3_scripted.ptl.zip[deeplabv3_scripted.ptl]", "format": "TorchScript v1.6", - "assert": "model.graphs[0].nodes[0].inputs[1].value.type.name == '__torch__.torch.classes.xnnpack.Conv2dOpContext'", "link": "https://github.com/lutzroeder/netron/issues/842" }, { @@ -5334,7 +5333,7 @@ "type": "pytorch", "target": "fasterrcnn_resnet50_fpn.pt", "source": "https://github.com/lutzroeder/netron/files/7677467/fasterrcnn_resnet50_fpn.pt.zip[fasterrcnn_resnet50_fpn.pt]", - "error": "Unsupported torch.add expression type.", + "error": "Unknown function 'torch.items'.", "link": "https://github.com/lutzroeder/netron/issues/689" }, { @@ -5491,7 +5490,7 @@ "target": "mask_model.pt", "source": "https://github.com/lutzroeder/netron/files/10080302/mask_model.pt.zip[mask_model.pt]", "format": "TorchScript v1.7", - "error": "Unsupported torch.add expression type.", + "error": "Unknown function 'torch.items'.", "link": "https://github.com/lutzroeder/netron/issues/842" }, { @@ -5785,7 +5784,7 @@ "type": "pytorch", "target": "netron_issue_547_1.pt", "source": "https://github.com/lutzroeder/netron/files/5137393/netron_issue_547_1.zip[netron_issue_547_1.pt]", - "error": "Unsupported torch.add expression type.", + "error": "Unknown function 
'torch.add'.", "link": "https://github.com/lutzroeder/netron/issues/547" }, { @@ -5821,7 +5820,6 @@ "target": "opt_xx.pt", "source": "https://github.com/lutzroeder/netron/files/8747908/opt_xx.pt.zip[opt_xx.pt]", "format": "TorchScript v1.6", - "error": "Slicing expected array", "link": "https://github.com/lutzroeder/netron/issues/913" }, { @@ -5849,7 +5847,7 @@ "type": "pytorch", "target": "pyg_model.pt", "source": "https://github.com/lutzroeder/netron/files/10369483/pyg_model.zip[pyg_model.pt]", - "error": "Expected \\'edge_index\\' to be of integer type (got \\'6\\')", + "error": "Unknown function 'torch.linear'.", "link": "https://github.com/lutzroeder/netron/issues/546" }, { @@ -5864,7 +5862,6 @@ "target": "quant_3d.pt", "source": "https://github.com/lutzroeder/netron/files/5877566/quant_3d.pt.zip[quant_3d.pt]", "format": "TorchScript v1.6", - "assert": "model.graphs[0].nodes[1].inputs[1].value.type.name == '__torch__.torch.classes.quantized.Conv3dPackedParamsBase'", "link": "https://github.com/lutzroeder/netron/issues/546" }, { @@ -5892,7 +5889,7 @@ "type": "pytorch", "target": "rcnn.pt", "source": "https://github.com/lutzroeder/netron/files/9035740/rcnn.pt.zip[rcnn.pt]", - "error": "AssertionError: expecting the last two dimensions of the Tensor to be H and W instead got []", + "error": "Unknown function 'torch.items'.", "link": "https://github.com/lutzroeder/netron/issues/842" }, { @@ -6346,7 +6343,6 @@ "target": "transformer.pt", "source": "https://github.com/lutzroeder/netron/files/10271969/transformer.pt.zip[transformer.pt]", "format": "TorchScript v1.6", - "error": "AssertionError: was expecting embedding dimension of 512, but got ?", "link": "https://github.com/lutzroeder/netron/issues/842" }, { @@ -6502,8 +6498,7 @@ "type": "pytorch", "target": "yolox_m.torchscript.pt", "source": "https://github.com/lutzroeder/netron/files/15031984/yolox_m.torchscript.pt.zip[yolox_m.torchscript.pt]", - "format": "TorchScript v1.5", - "error": "ValueError: not enough values to unpack (expected 3, actual 1).", + "format": "TorchScript v1.6", "link": "https://github.com/lutzroeder/netron/issues/842" }, { diff --git a/tools/pytorch_script.py b/tools/pytorch_script.py index 8343f8f491..d9e2f25265 100644 --- a/tools/pytorch_script.py +++ b/tools/pytorch_script.py @@ -59,20 +59,214 @@ def _write_metadata(value): ('aten/src/ATen/native/RNN.cpp', re.compile(r'TORCH_SELECTIVE_SCHEMA\("(.*)"', re.MULTILINE)), ('torch/jit/_shape_functions.py', - re.compile(r'(prim::.*->\s*.*)"', re.MULTILINE)) + re.compile(r'(prim::.*->\s*.*)"', re.MULTILINE)), + ('torch/csrc/jit/runtime/static/native_ops.cpp', + re.compile(r'(prim::.*->\s*.*)"', re.MULTILINE)), ] known_schema_definitions = [ - 'aten::as_tensor(Tensor(a) data, *, ScalarType? dtype=None, Device? 
device=None) -> Tensor(b|a)', # pylint: disable=line-too-long + 'aten::__and__.bool(bool a, bool b) -> bool', + 'aten::__and__.int(int a, int b) -> int', + 'aten::__and__.Scalar(Tensor self, Scalar other) -> Tensor', + 'aten::__and__.Tensor(Tensor self, Tensor other) -> Tensor', + 'aten::__getitem__.Dict_bool(Dict(bool, t) self, bool key) -> t(*)', + 'aten::__getitem__.Dict_complex(Dict(complex, t) self, complex key) -> t(*)', + 'aten::__getitem__.Dict_float(Dict(float, t) self, float key) -> t(*)', + 'aten::__getitem__.Dict_int(Dict(int, t) self, int key) -> t(*)', + 'aten::__getitem__.Dict_str(Dict(str, t) self, str key) -> t(*)', + 'aten::__getitem__.Dict_Tensor(Dict(Tensor, t) self, Tensor key) -> t(*)', + 'aten::__getitem__.str(str s, int index) -> str', + 'aten::__getitem__.t(t[](a) list, int idx) -> t(*)', + 'prim::abs(Tensor x) -> Tensor', + 'prim::abs.complex(complex a) -> float', + 'prim::abs.float(float a) -> float', + 'prim::abs.int(int a) -> int', + 'prim::abs.Scalar(Scalar a) -> Scalar', + 'aten::any(Tensor self) -> Tensor', + 'aten::any.all_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)', + 'aten::any.bool(bool[] self) -> bool', + 'aten::any.dim(Tensor self, int dim, bool keepdim=False) -> Tensor', + 'aten::any.dimname(Tensor self, str dim, bool keepdim=False) -> Tensor', + 'aten::any.dimname_out(Tensor self, str dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)', # pylint: disable=line-too-long + 'aten::any.dims(Tensor self, int[]? dim=None, bool keepdim=False) -> Tensor', + 'aten::any.dims_out(Tensor self, int[]? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)', # pylint: disable=line-too-long + 'aten::any.float(float[] self) -> bool', + 'aten::any.int(int[] self) -> bool', + 'aten::any.out(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)', + 'aten::any.str(str[] self) -> bool', 'aten::as_tensor.bool(bool t, *, ScalarType? dtype=None, Device? device=None) -> Tensor', 'aten::as_tensor.complex(complex t, *, ScalarType? dtype=None, Device? device=None) -> Tensor', 'aten::as_tensor.float(float t, *, ScalarType? dtype=None, Device? device=None) -> Tensor', 'aten::as_tensor.int(int t, *, ScalarType? dtype=None, Device? device=None) -> Tensor', 'aten::as_tensor.list(t[] data, *, ScalarType? dtype=None, Device? device=None) -> Tensor', - 'aten::searchsorted.Tensor(Tensor sorted_sequence, Tensor self, *, bool out_int32=False, bool right=False, str? side=None, Tensor? sorter=None) -> Tensor', # pylint: disable=line-too-long - 'aten::searchsorted.Tensor_out(Tensor sorted_sequence, Tensor self, *, bool out_int32=False, bool right=False, str? side=None, Tensor? sorter=None, Tensor(a!) out) -> Tensor(a!)', # pylint: disable=line-too-long - 'aten::searchsorted.Scalar(Tensor sorted_sequence, Scalar self, *, bool out_int32=False, bool right=False, str? side=None, Tensor? sorter=None) -> Tensor', # pylint: disable=line-too-long + 'aten::as_tensor(Tensor(a) data, *, ScalarType? dtype=None, Device? device=None) -> Tensor(b|a)', # pylint: disable=line-too-long + 'aten::ceil.float(float a) -> int', + 'aten::ceil.int(int a) -> int', + 'aten::ceil.out(Tensor self, *, Tensor(a!) 
out) -> Tensor(a!)', + 'aten::ceil.Scalar(Scalar a) -> Scalar', + 'aten::ceil(Tensor self) -> Tensor', + 'aten::dict.Dict_str(Dict(str, t)(a) self) -> Dict(str, t)', + 'aten::dict() -> Dict(str, Tensor)', + 'aten::div.complex(complex a, complex b) -> complex', + 'aten::div.float(float a, float b) -> float', + 'aten::div.int(int a, int b) -> float', + 'aten::div.out_mode(Tensor self, Tensor other, *, str? rounding_mode, Tensor(a!) out) -> Tensor(a!)', # pylint: disable=line-too-long + 'aten::div.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)', + 'aten::div.Scalar_mode_out(Tensor self, Scalar other, *, str? rounding_mode, Tensor(a!) out) -> Tensor(a!)', # pylint: disable=line-too-long + 'aten::div.Scalar_mode(Tensor self, Scalar other, *, str? rounding_mode) -> Tensor', + 'aten::div.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)', + 'aten::div.Scalar(Tensor self, Scalar other) -> Tensor', + 'aten::div.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> Tensor', + 'aten::div.Tensor(Tensor self, Tensor other) -> Tensor', + 'aten::div(Scalar a, Scalar b) -> float', + 'aten::eq(Scalar a, Scalar b) -> bool', + 'aten::eq.bool(bool a, bool b) -> bool', + 'aten::eq.bool_list(bool[] a, bool[] b) -> bool', + 'aten::eq.complex(complex a, complex b) -> bool', + 'aten::eq.complex_float(complex a, float b) -> bool', + 'aten::eq.device(Device a, Device b) -> bool', + 'aten::eq.enum(AnyEnumType a, AnyEnumType b) -> bool', + 'aten::eq.float(float a, float b) -> bool', + 'aten::eq.float_complex(float a, complex b) -> bool', + 'aten::eq.float_int(float a, int b) -> bool', + 'aten::eq.float_list(float[] a, float[] b) -> bool', + 'aten::eq.int(int a, int b) -> bool', + 'aten::eq.int_float(int a, float b) -> bool', + 'aten::eq.int_list(int[] a, int[] b) -> bool', + 'aten::eq.Scalar(Tensor self, Scalar other) -> Tensor', + 'aten::eq.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)', + 'aten::eq.str(str a, str b) -> bool', + 'aten::eq.str_list(str[] a, str[] b) -> bool', + 'aten::eq.Tensor(Tensor self, Tensor other) -> Tensor', + 'aten::eq.Tensor_list(Tensor[] a, Tensor[] b) -> bool', + 'aten::eq.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)', + 'aten::eq_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)', + 'aten::eq_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)', + 'aten::equal(Tensor self, Tensor other) -> bool', + 'aten::extend.t(t[](a!) self, t[] other) -> ()', + 'aten::gt(Scalar a, Scalar b) -> bool', + 'aten::gt.float(float a, float b) -> bool', + 'aten::gt.float_int(float a, int b) -> bool', + 'aten::gt.int(int a, int b) -> bool', + 'aten::gt.int_float(int a, float b) -> bool', + 'aten::gt.Scalar(Tensor self, Scalar other) -> Tensor', + 'aten::gt.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)', + 'aten::gt.str(str a, str b) -> bool', + 'aten::gt.Tensor(Tensor self, Tensor other) -> Tensor', + 'aten::gt.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)', + 'aten::gt_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)', + 'aten::gt_.Tensor(Tensor(a!) 
self, Tensor other) -> Tensor(a!)', + 'aten::item(Tensor self) -> Scalar', + # 'aten::items.bool(Dict(bool, t) self) -> ((bool, t)[])', + # 'aten::items.complex(Dict(complex, t) self) -> ((complex, t)[])', + # 'aten::items.float(Dict(float, t) self) -> ((float, t)[])', + # 'aten::items.int(Dict(int, t) self) -> ((int, t)[])', + # 'aten::items.str(Dict(str, t) self) -> ((str, t)[])', + # 'aten::items.Tensor(Dict(Tensor, t) self) -> ((Tensor, t)[])', + 'aten::keys.bool(Dict(bool, t) self) -> bool[](*)', + 'aten::keys.complex(Dict(complex, t) self) -> complex[](*)', + 'aten::keys.float(Dict(float, t) self) -> float[](*)', + 'aten::keys.int(Dict(int, t) self) -> int[](*)', + 'aten::keys.str(Dict(str, t) self) -> str[](*)', + 'aten::keys.Tensor(Dict(Tensor, t) self) -> Tensor[](*)', + 'aten::log10.complex(complex a) -> complex', + 'aten::log10.float(float a) -> float', + 'aten::log10.int(int a) -> float', + 'aten::log10.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)', + 'aten::log10.Scalar(Scalar a) -> Scalar', + 'aten::log10(Tensor self) -> Tensor', + 'aten::le(Scalar a, Scalar b) -> bool', + 'aten::le.float(float a, float b) -> bool', + 'aten::le.float_int(float a, int b) -> bool', + 'aten::le.int(int a, int b) -> bool', + 'aten::le.int_float(int a, float b) -> bool', + 'aten::le.Scalar(Tensor self, Scalar other) -> Tensor', + 'aten::le.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)', + 'aten::le.str(str a, str b) -> bool', + 'aten::le.Tensor(Tensor self, Tensor other) -> Tensor', + 'aten::le.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)', + 'aten::le_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)', + 'aten::le_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)', + 'aten::lt(Scalar a, Scalar b) -> bool', + 'aten::lt.float(float a, float b) -> bool', + 'aten::lt.float_int(float a, int b) -> bool', + 'aten::lt.int(int a, int b) -> bool', + 'aten::lt.int_float(int a, float b) -> bool', + 'aten::lt.Scalar(Tensor self, Scalar other) -> Tensor', + 'aten::lt.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)', + 'aten::lt.str(str a, str b) -> bool', + 'aten::lt.Tensor(Tensor self, Tensor other) -> Tensor', + 'aten::lt.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)', + 'aten::lt_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)', + 'aten::lt_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)', + 'aten::remainder.float_int(float a, int b) -> float', + 'aten::remainder.float(float a, float b) -> float', + 'aten::remainder.int_float(int a, float b) -> float', + 'aten::remainder.int(int a, int b) -> int', + 'aten::remainder.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)', + 'aten::remainder.Scalar_Tensor_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)', + 'aten::remainder.Scalar_Tensor(Scalar self, Tensor other) -> Tensor', + 'aten::remainder.Scalar(Tensor self, Scalar other) -> Tensor', + 'aten::remainder.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)', + 'aten::remainder.Tensor(Tensor self, Tensor other) -> Tensor', + 'aten::remainder(Scalar a, Scalar b) -> Scalar', + 'aten::replace(str self, str old, str new, int max=-1) -> str', 'aten::searchsorted.Scalar_out(Tensor sorted_sequence, Scalar self, *, bool out_int32=False, bool right=False, str? side=None, Tensor? sorter=None, Tensor(a!) 
out) -> Tensor(a!)', # pylint: disable=line-too-long + 'aten::searchsorted.Scalar(Tensor sorted_sequence, Scalar self, *, bool out_int32=False, bool right=False, str? side=None, Tensor? sorter=None) -> Tensor', # pylint: disable=line-too-long + 'aten::searchsorted.Tensor_out(Tensor sorted_sequence, Tensor self, *, bool out_int32=False, bool right=False, str? side=None, Tensor? sorter=None, Tensor(a!) out) -> Tensor(a!)', # pylint: disable=line-too-long + 'aten::searchsorted.Tensor(Tensor sorted_sequence, Tensor self, *, bool out_int32=False, bool right=False, str? side=None, Tensor? sorter=None) -> Tensor', # pylint: disable=line-too-long + 'aten::sqrt(Tensor self) -> Tensor', + 'aten::sqrt.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)', + 'aten::sqrt.int(int a) -> float', + 'aten::sqrt.float(float a) -> float', + 'aten::sqrt.complex(complex a) -> complex', + 'aten::sqrt.Scalar(Scalar a) -> Scalar', + 'aten::values.bool(Dict(bool, t) self) -> t[](*)', + 'aten::values.complex(Dict(complex, t) self) -> t[](*)', + 'aten::values.float(Dict(float, t) self) -> t[](*)', + 'aten::values.int(Dict(int, t) self) -> t[](*)', + 'aten::values.str(Dict(str, t) self) -> t[](*)', + 'aten::values.Tensor(Dict(Tensor, t) self) -> t[](*)', + 'aten::values(Tensor(a) self) -> Tensor(a)', + 'prim::is_cpu(Tensor a) -> bool', + 'prim::is_cuda(Tensor a) -> bool', + 'prim::is_ipu(Tensor a) -> bool', + 'prim::is_maia(Tensor a) -> bool', + 'prim::is_meta(Tensor a) -> bool', + 'prim::is_mkldnn(Tensor a) -> bool', + 'prim::is_mps(Tensor a) -> bool', + 'prim::is_mtia(Tensor a) -> bool', + 'prim::is_nested(Tensor a) -> bool', + 'prim::is_quantized(Tensor a) -> bool', + 'prim::is_sparse(Tensor a) -> bool', + 'prim::is_sparse_csr(Tensor a) -> bool', + 'prim::is_vulkan(Tensor a) -> bool', + 'prim::is_xla(Tensor a) -> bool', + 'prim::is_xpu(Tensor a) -> bool', + 'prim::itemsize(Tensor a) -> int', + 'prim::layout(Tensor a) -> Layout', + 'prim::max(Scalar a, Scalar b) -> Scalar', + 'prim::max.bool_list(bool[] l, bool[] r) -> bool[]', + 'prim::max.float(float a, float b) -> float', + 'prim::max.float_int(float a, int b) -> float', + 'prim::max.float_list(float[] l, float[] r) -> float[]', + 'prim::max.int(int a, int b) -> int', + 'prim::max.int_float(int a, float b) -> float', + 'prim::max.int_list(int[] l, int[] r) -> int[]', + 'prim::max.self_bool(bool[] self) -> bool', + 'prim::max.self_float(float[] self) -> float', + 'prim::max.self_int(int[] self) -> int', + 'prim::min(Scalar a, Scalar b) -> Scalar', + 'prim::min.bool_list(bool[] l, bool[] r) -> bool[]', + 'prim::min.float(float a, float b) -> float', + 'prim::min.float_int(float a, int b) -> float', + 'prim::min.float_list(float[] l, float[] r) -> float[]', + 'prim::min.int(int a, int b) -> int', + 'prim::min.int_float(int a, float b) -> float', + 'prim::min.int_list(int[] l, int[] r) -> int[]', + 'prim::min.self_bool(bool[] self) -> bool', + 'prim::min.self_float(float[] self) -> float', + 'prim::min.self_int(int[] self) -> int', ] def _parse_schemas(): @@ -99,7 +293,6 @@ def _parse_schemas(): return schemas def _filter_schemas(schemas, types): - keys = set(map(lambda _: _.split('.')[0], types.keys())) filtered_schemas = set() for schema in schemas.values(): @@ -162,23 +355,34 @@ def _check_types(types, schemas): 'aten::arange.start_out_', 'aten::classes._nnapi.Compilation', 'aten::fft', - 'aten::gt.float_int', - 'aten::gt.float', - 'aten::gt.int_float', - 'aten::gt.int', - 'aten::le.float_int', - 'aten::le.float', - 'aten::le.int_float', - 'aten::le.int', + 
'aten::floor.float', + 'aten::floor.int', + 'aten::floor.Scalar', + 'aten::floordiv.float_int', + 'aten::floordiv.float', + 'aten::floordiv.int_float', + 'aten::floordiv.int', + 'aten::floordiv.Scalar', + 'aten::grid_sampler.legacy', + 'aten::mul.float_int', + 'aten::mul.int_float', + 'aten::mul.int', 'aten::mul.ScalarT', - 'aten::remainder.float32', - 'aten::remainder.int', + 'aten::mul', + 'aten::ne.float', + 'aten::ne.int', + 'aten::ne.str', + 'aten::neg.complex', + 'aten::neg.float', + 'aten::neg.int', + 'aten::neg.Scalar', 'aten::sub.float', 'aten::sub.int', 'aten::sub.str', 'aten::tensor.bool', 'aten::tensor.float', - 'aten::tensor.int' + 'aten::tensor.int', + 'prim::shape', ] for key in known_keys: types.pop(key)