diff --git a/tf.js b/tf.js index 72095d2..c7803f3 100644 --- a/tf.js +++ b/tf.js @@ -8,7 +8,7 @@ const outputShape = parseInt(args[1]); const hiddenSize = parseInt(args[2]); // Define a model for linear regression. -const model = tf.sequential(); +var model = tf.sequential(); @@ -47,3 +47,41 @@ maxApi.addHandler("predict", (...data) => { }); +async function getWeights() { + let weights = model.getWeights(); + for (let i = 0; i < weights.length; i++){ + let data = await weights[i].data(); + let shape = weights[i].shape; + weights[i] = {data, shape}; + + } + return weights; +} + +maxApi.addHandler("save", (dictId, key) => { + maxApi.getDict(dictId).then((dict) => { + getWeights().then((weights) => { + console.log(weights); + dict[key] = weights; + maxApi.setDict(dictId, dict); + }); + }); +}); + +maxApi.addHandler("load", (dictId, key) => { + maxApi.getDict(dictId).then((dict) => { + let data = dict[key]; + let tensors = []; + data.forEach(item => { + let shape = item.shape; + let vals = []; + for (const [key, value] of Object.entries(item.data)) { + vals.push(value); + + } + tensors.push(tf.tensor(vals, shape)); + }); + model.setWeights(tensors); + }); +}); + diff --git a/tf.maxpat b/tf.maxpat index 29ba1c7..eae609b 100644 --- a/tf.maxpat +++ b/tf.maxpat @@ -10,7 +10,7 @@ } , "classnamespace" : "box", - "rect" : [ 1649.0, 182.0, 948.0, 575.0 ], + "rect" : [ 164.0, 165.0, 948.0, 575.0 ], "bglocked" : 0, "openinpresentation" : 0, "default_fontsize" : 12.0, @@ -39,6 +39,117 @@ "subpatcher_template" : "", "assistshowspatchername" : 0, "boxes" : [ { + "box" : { + "id" : "obj-40", + "linecount" : 2, + "maxclass" : "message", + "numinlets" : 2, + "numoutlets" : 1, + "outlettype" : [ "" ], + "patching_rect" : [ 237.0, 453.0, 50.0, 36.0 ], + "text" : "9. 10. 11." + } + + } +, { + "box" : { + "id" : "obj-36", + "maxclass" : "message", + "numinlets" : 2, + "numoutlets" : 1, + "outlettype" : [ "" ], + "patching_rect" : [ 768.0, 156.0, 111.0, 22.0 ], + "text" : "load weights kieran" + } + + } +, { + "box" : { + "id" : "obj-35", + "maxclass" : "message", + "numinlets" : 2, + "numoutlets" : 1, + "outlettype" : [ "" ], + "patching_rect" : [ 626.0, 156.0, 114.0, 22.0 ], + "text" : "save weights kieran" + } + + } +, { + "box" : { + "data" : { + "kieran" : [ { + "data" : { + "0" : 0.927409112453461, + "1" : -0.10581861436367, + "2" : 0.110250025987625, + "3" : -1.001458287239075, + "4" : 0.689840376377106, + "5" : 0.396656483411789, + "6" : -1.090744495391846, + "7" : -0.033253565430641 + } +, + "shape" : [ 2, 4 ] + } +, { + "data" : { + "0" : -0.04672335088253, + "1" : 0.441919326782227, + "2" : -0.127905443310738, + "3" : 0.395313948392868 + } +, + "shape" : [ 4 ] + } +, { + "data" : { + "0" : 0.714273393154144, + "1" : 0.45893207192421, + "2" : 0.772015511989594, + "3" : -0.403766572475433, + "4" : 0.378416150808334, + "5" : 0.486913919448853, + "6" : -0.327280908823013, + "7" : -0.424903750419617, + "8" : -0.59527575969696, + "9" : -0.619619905948639, + "10" : -0.706632375717163, + "11" : -0.025489492341876 + } +, + "shape" : [ 4, 3 ] + } +, { + "data" : { + "0" : -0.584373593330383, + "1" : 0.079081602394581, + "2" : 0.754254102706909 + } +, + "shape" : [ 3 ] + } + ] + } +, + "id" : "obj-33", + "maxclass" : "newobj", + "numinlets" : 2, + "numoutlets" : 5, + "outlettype" : [ "dictionary", "", "", "", "" ], + "patching_rect" : [ 768.0, 215.0, 133.0, 22.0 ], + "saved_object_attributes" : { + "embed" : 1, + "legacy" : 0, + "parameter_enable" : 0, + "parameter_mappable" : 0 + } +, + "text" : "dict weights @embed 1" + } + + } +, { "box" : { "id" : "obj-32", "maxclass" : "comment", @@ -68,7 +179,7 @@ "maxclass" : "comment", "numinlets" : 1, "numoutlets" : 0, - "patching_rect" : [ 339.0, 44.0, 150.0, 34.0 ], + "patching_rect" : [ 339.0, 44.0, 152.0, 34.0 ], "text" : "Uses basic SGD optimizer with 0.01 learning rate" } @@ -150,7 +261,6 @@ "numoutlets" : 1, "outlettype" : [ "" ], "patching_rect" : [ 262.0, 163.0, 55.0, 22.0 ], - "presentation_linecount" : 3, "text" : "3 3 5 6 7" } @@ -163,7 +273,6 @@ "numoutlets" : 1, "outlettype" : [ "" ], "patching_rect" : [ 330.0, 163.0, 55.0, 22.0 ], - "presentation_linecount" : 3, "text" : "4 4 7 8 9" } @@ -176,7 +285,6 @@ "numoutlets" : 1, "outlettype" : [ "" ], "patching_rect" : [ 197.0, 163.0, 55.0, 22.0 ], - "presentation_linecount" : 3, "text" : "2 2 3 4 5" } @@ -192,19 +300,6 @@ "text" : "1 1 1 2 3" } - } -, { - "box" : { - "id" : "obj-11", - "linecount" : 6, - "maxclass" : "message", - "numinlets" : 2, - "numoutlets" : 1, - "outlettype" : [ "" ], - "patching_rect" : [ 192.0, 413.0, 50.0, 91.0 ], - "text" : "8.914521 10.032911 11.035411" - } - } , { "box" : { @@ -246,7 +341,7 @@ "numoutlets" : 1, "offset" : [ 0.0, 0.0 ], "outlettype" : [ "bang" ], - "patching_rect" : [ 484.0, 353.0, 400.0, 220.0 ], + "patching_rect" : [ 526.0, 284.0, 400.0, 220.0 ], "viewvisibility" : 1 } @@ -275,8 +370,6 @@ "args" : [ 2, 3, 4 ], "autostart" : 1, "defer" : 0, - "node_bin_path" : "", - "npm_bin_path" : "", "watch" : 1 } , @@ -287,7 +380,7 @@ ], "lines" : [ { "patchline" : { - "destination" : [ "obj-11", 1 ], + "destination" : [ "obj-40", 1 ], "source" : [ "obj-1", 0 ] } @@ -368,6 +461,20 @@ "source" : [ "obj-20", 0 ] } + } +, { + "patchline" : { + "destination" : [ "obj-1", 0 ], + "source" : [ "obj-35", 0 ] + } + + } +, { + "patchline" : { + "destination" : [ "obj-1", 0 ], + "source" : [ "obj-36", 0 ] + } + } , { "patchline" : {