From 5ec1afe4b47a429fb4d019209c17326e53c86255 Mon Sep 17 00:00:00 2001 From: trian-gles <69212477+trian-gles@users.noreply.github.com> Date: Mon, 1 Jul 2024 00:03:51 +0200 Subject: [PATCH] native like dictionary loading behaviour --- tf.js | 19 +++- tf.maxpat | 327 +++++++++++++++++++++++++++++++++++++++++------------- 2 files changed, 262 insertions(+), 84 deletions(-) diff --git a/tf.js b/tf.js index 634361b..855e2a9 100644 --- a/tf.js +++ b/tf.js @@ -39,6 +39,7 @@ maxApi.addHandler("dataPoint", (...data) => { ysArr.push(data.slice(inputShape)); }); + maxApi.addHandler("predict", (...data) => { data.map((item) => parseFloat(item)); model.predict(tf.tensor2d([data], [1, inputShape])).array().then((value) => { @@ -74,14 +75,11 @@ async function getJson() { return json; } - -maxApi.addHandler("load", (dictId, key) => { - maxApi.getDict(dictId).then((dict) => { - let arch = dict[key].model; - tf.models.modelFromJSON(dict[key].model).then((m) => { +function loadWeights(dict){ + tf.models.modelFromJSON(dict.model).then((m) => { model = m - let data = dict[key].weights; + let data = dict.weights; let tensors = []; data.forEach(item => { let shape = item.shape; @@ -94,6 +92,15 @@ maxApi.addHandler("load", (dictId, key) => { }); model.setWeights(tensors); }); +} + + +maxApi.addHandler("set_weights", loadWeights); + + +maxApi.addHandler("load", (dictId, key) => { + maxApi.getDict(dictId).then((dict) => { + loadWeights(dict[key]); }); }); diff --git a/tf.maxpat b/tf.maxpat index b2a5b16..b9c181d 100644 --- a/tf.maxpat +++ b/tf.maxpat @@ -10,7 +10,7 @@ } , "classnamespace" : "box", - "rect" : [ 164.0, 165.0, 732.0, 575.0 ], + "rect" : [ 44.0, 125.0, 1036.0, 622.0 ], "bglocked" : 0, "openinpresentation" : 0, "default_fontsize" : 12.0, @@ -40,13 +40,85 @@ "assistshowspatchername" : 0, "boxes" : [ { "box" : { - "id" : "obj-23", + "id" : "obj-46", "maxclass" : "newobj", + "numinlets" : 1, + "numoutlets" : 0, + "patching_rect" : [ 128.0, 434.0, 32.0, 22.0 ], + "text" : "print" + } + + } +, { + "box" : { + "id" : "obj-40", + "linecount" : 5, + "maxclass" : "message", "numinlets" : 2, - "numoutlets" : 2, - "outlettype" : [ "", "" ], - "patching_rect" : [ 160.0, 336.0, 80.0, 22.0 ], - "text" : "route weights" + "numoutlets" : 1, + "outlettype" : [ "" ], + "patching_rect" : [ 199.0, 355.0, 50.0, 77.0 ], + "text" : "-2.7 -1.759611 3.244076" + } + + } +, { + "box" : { + "id" : "obj-38", + "maxclass" : "newobj", + "numinlets" : 3, + "numoutlets" : 3, + "outlettype" : [ "", "", "" ], + "patching_rect" : [ 230.0, 325.0, 98.0, 22.0 ], + "text" : "route list weights" + } + + } +, { + "box" : { + "id" : "obj-37", + "linecount" : 4, + "maxclass" : "comment", + "numinlets" : 1, + "numoutlets" : 0, + "patching_rect" : [ 510.0, 97.0, 150.0, 62.0 ], + "text" : "Send a dictionary with weights and architecture prepended by \"set_weights\" to load" + } + + } +, { + "box" : { + "id" : "obj-2", + "maxclass" : "newobj", + "numinlets" : 1, + "numoutlets" : 1, + "outlettype" : [ "" ], + "patching_rect" : [ 482.0, 192.0, 119.0, 22.0 ], + "text" : "prepend set_weights" + } + + } +, { + "box" : { + "id" : "obj-31", + "maxclass" : "button", + "numinlets" : 1, + "numoutlets" : 1, + "outlettype" : [ "bang" ], + "parameter_enable" : 0, + "patching_rect" : [ 482.0, 130.0, 24.0, 24.0 ] + } + + } +, { + "box" : { + "id" : "obj-27", + "linecount" : 3, + "maxclass" : "comment", + "numinlets" : 1, + "numoutlets" : 0, + "patching_rect" : [ 654.0, 175.0, 150.0, 48.0 ], + "text" : "output a dictionary containing model architecture and weights" } } @@ -56,7 +128,7 @@ "maxclass" : "dict.view", "numinlets" : 1, "numoutlets" : 0, - "patching_rect" : [ 28.0, 399.0, 231.0, 75.0 ] + "patching_rect" : [ 305.0, 382.0, 220.0, 214.0 ] } } @@ -79,7 +151,7 @@ "numinlets" : 2, "numoutlets" : 5, "outlettype" : [ "dictionary", "", "", "", "" ], - "patching_rect" : [ 85.0, 353.0, 61.0, 22.0 ], + "patching_rect" : [ 305.0, 355.0, 61.0, 22.0 ], "saved_object_attributes" : { "embed" : 0, "legacy" : 0, @@ -114,59 +186,144 @@ "text" : "0.5. Restart the main script" } - } -, { - "box" : { - "id" : "obj-3", - "linecount" : 2, - "maxclass" : "comment", - "numinlets" : 1, - "numoutlets" : 0, - "patching_rect" : [ 765.0, 18.0, 150.0, 34.0 ], - "text" : "4. save/load your trained models" - } - - } -, { - "box" : { - "id" : "obj-40", - "linecount" : 5, - "maxclass" : "message", - "numinlets" : 2, - "numoutlets" : 1, - "outlettype" : [ "" ], - "patching_rect" : [ 281.0, 450.0, 50.0, 77.0 ], - "text" : "weights dictionary u266000897" - } - - } -, { - "box" : { - "id" : "obj-36", - "maxclass" : "message", - "numinlets" : 2, - "numoutlets" : 1, - "outlettype" : [ "" ], - "patching_rect" : [ 785.0, 101.0, 107.0, 22.0 ], - "text" : "load myDict kieran" - } - - } -, { - "box" : { - "id" : "obj-35", - "maxclass" : "message", - "numinlets" : 2, - "numoutlets" : 1, - "outlettype" : [ "" ], - "patching_rect" : [ 783.0, 69.0, 110.0, 22.0 ], - "text" : "save myDict kieran" - } - } , { "box" : { "data" : { + "weights" : [ { + "data" : { + "0" : -0.044006392359734, + "1" : 0.037017118185759, + "2" : -0.603930592536926, + "3" : 0.060808453708887, + "4" : 0.180193305015564, + "5" : 0.282816737890244, + "6" : 0.363093823194504, + "7" : 0.173229858279228 + } +, + "shape" : [ 2, 4 ] + } +, { + "data" : { + "0" : 0, + "1" : 0, + "2" : 0, + "3" : 0 + } +, + "shape" : [ 4 ] + } +, { + "data" : { + "0" : 0.224349781870842, + "1" : 0.860164880752563, + "2" : -0.442636489868164, + "3" : -0.214278265833855, + "4" : 0.491469591856003, + "5" : -0.450522541999817, + "6" : 0.438913941383362, + "7" : -0.085629880428314, + "8" : 0.372829973697662, + "9" : -0.778826951980591, + "10" : 0.724699020385742, + "11" : 0.034776996821165 + } +, + "shape" : [ 4, 3 ] + } +, { + "data" : { + "0" : 0, + "1" : 0, + "2" : 0 + } +, + "shape" : [ 3 ] + } + ], + "model" : { + "class_name" : "Sequential", + "config" : { + "name" : "sequential_1", + "layers" : [ { + "class_name" : "Dense", + "config" : { + "units" : 4, + "activation" : "relu", + "use_bias" : 1, + "kernel_initializer" : { + "class_name" : "VarianceScaling", + "config" : { + "scale" : 1, + "mode" : "fan_avg", + "distribution" : "normal", + "seed" : null + } + + } +, + "bias_initializer" : { + "class_name" : "Zeros", + "config" : { + + } + + } +, + "kernel_regularizer" : null, + "bias_regularizer" : null, + "activity_regularizer" : null, + "kernel_constraint" : null, + "bias_constraint" : null, + "name" : "dense_Dense1", + "trainable" : 1, + "batch_input_shape" : [ null, 2 ], + "dtype" : "float32" + } + + } +, { + "class_name" : "Dense", + "config" : { + "units" : 3, + "activation" : "linear", + "use_bias" : 1, + "kernel_initializer" : { + "class_name" : "VarianceScaling", + "config" : { + "scale" : 1, + "mode" : "fan_avg", + "distribution" : "normal", + "seed" : null + } + + } +, + "bias_initializer" : { + "class_name" : "Zeros", + "config" : { + + } + + } +, + "kernel_regularizer" : null, + "bias_regularizer" : null, + "activity_regularizer" : null, + "kernel_constraint" : null, + "bias_constraint" : null, + "name" : "dense_Dense2", + "trainable" : 1 + } + + } + ] + } +, + "keras_version" : "tfjs-layers 4.20.0", + "backend" : "tensor_flow.js" + } } , @@ -175,7 +332,7 @@ "numinlets" : 2, "numoutlets" : 5, "outlettype" : [ "dictionary", "", "", "", "" ], - "patching_rect" : [ 768.0, 215.0, 129.0, 22.0 ], + "patching_rect" : [ 482.0, 161.0, 129.0, 22.0 ], "saved_object_attributes" : { "embed" : 1, "legacy" : 0, @@ -201,12 +358,12 @@ , { "box" : { "id" : "obj-30", - "linecount" : 3, + "linecount" : 4, "maxclass" : "comment", "numinlets" : 1, "numoutlets" : 0, - "patching_rect" : [ 484.0, 230.0, 150.0, 48.0 ], - "text" : "args are input shape, output shape, and size of the single hidden layer" + "patching_rect" : [ 491.0, 238.0, 150.0, 62.0 ], + "text" : "args are input shape, output shape, and size of the single hidden layer. Always Relu activation." } } @@ -346,7 +503,7 @@ "numinlets" : 2, "numoutlets" : 1, "outlettype" : [ "" ], - "patching_rect" : [ 560.0, 74.0, 65.0, 22.0 ], + "patching_rect" : [ 560.0, 56.0, 65.0, 22.0 ], "text" : "predict 5 5" } @@ -379,7 +536,7 @@ "numoutlets" : 1, "offset" : [ 0.0, 0.0 ], "outlettype" : [ "bang" ], - "patching_rect" : [ 523.0, 342.0, 400.0, 220.0 ], + "patching_rect" : [ 560.0, 353.0, 400.0, 220.0 ], "viewvisibility" : 1 } @@ -418,16 +575,7 @@ ], "lines" : [ { "patchline" : { - "destination" : [ "obj-23", 0 ], - "order" : 1, - "source" : [ "obj-1", 0 ] - } - - } -, { - "patchline" : { - "destination" : [ "obj-40", 1 ], - "order" : 0, + "destination" : [ "obj-38", 0 ], "source" : [ "obj-1", 0 ] } @@ -515,6 +663,13 @@ "source" : [ "obj-18", 2 ] } + } +, { + "patchline" : { + "destination" : [ "obj-1", 0 ], + "source" : [ "obj-2", 0 ] + } + } , { "patchline" : { @@ -529,25 +684,41 @@ "source" : [ "obj-21", 0 ] } + } +, { + "patchline" : { + "destination" : [ "obj-33", 0 ], + "source" : [ "obj-31", 0 ] + } + + } +, { + "patchline" : { + "destination" : [ "obj-2", 0 ], + "source" : [ "obj-33", 0 ] + } + } , { "patchline" : { "destination" : [ "obj-12", 0 ], - "source" : [ "obj-23", 0 ] + "source" : [ "obj-38", 1 ] } } , { "patchline" : { - "destination" : [ "obj-1", 0 ], - "source" : [ "obj-35", 0 ] + "destination" : [ "obj-40", 1 ], + "order" : 0, + "source" : [ "obj-38", 0 ] } } , { "patchline" : { - "destination" : [ "obj-1", 0 ], - "source" : [ "obj-36", 0 ] + "destination" : [ "obj-46", 0 ], + "order" : 1, + "source" : [ "obj-38", 0 ] } }