saving and loading in max dicts

This commit is contained in:
trian-gles 2024-06-27 16:00:06 +02:00
parent 847db766d2
commit 43cd89faa5
2 changed files with 168 additions and 23 deletions

40
tf.js
View File

@ -8,7 +8,7 @@ const outputShape = parseInt(args[1]);
const hiddenSize = parseInt(args[2]);
// Define a model for linear regression.
const model = tf.sequential();
var model = tf.sequential();
@ -47,3 +47,41 @@ maxApi.addHandler("predict", (...data) => {
});
async function getWeights() {
let weights = model.getWeights();
for (let i = 0; i < weights.length; i++){
let data = await weights[i].data();
let shape = weights[i].shape;
weights[i] = {data, shape};
}
return weights;
}
maxApi.addHandler("save", (dictId, key) => {
maxApi.getDict(dictId).then((dict) => {
getWeights().then((weights) => {
console.log(weights);
dict[key] = weights;
maxApi.setDict(dictId, dict);
});
});
});
maxApi.addHandler("load", (dictId, key) => {
maxApi.getDict(dictId).then((dict) => {
let data = dict[key];
let tensors = [];
data.forEach(item => {
let shape = item.shape;
let vals = [];
for (const [key, value] of Object.entries(item.data)) {
vals.push(value);
}
tensors.push(tf.tensor(vals, shape));
});
model.setWeights(tensors);
});
});

151
tf.maxpat
View File

@ -10,7 +10,7 @@
}
,
"classnamespace" : "box",
"rect" : [ 1649.0, 182.0, 948.0, 575.0 ],
"rect" : [ 164.0, 165.0, 948.0, 575.0 ],
"bglocked" : 0,
"openinpresentation" : 0,
"default_fontsize" : 12.0,
@ -39,6 +39,117 @@
"subpatcher_template" : "",
"assistshowspatchername" : 0,
"boxes" : [ {
"box" : {
"id" : "obj-40",
"linecount" : 2,
"maxclass" : "message",
"numinlets" : 2,
"numoutlets" : 1,
"outlettype" : [ "" ],
"patching_rect" : [ 237.0, 453.0, 50.0, 36.0 ],
"text" : "9. 10. 11."
}
}
, {
"box" : {
"id" : "obj-36",
"maxclass" : "message",
"numinlets" : 2,
"numoutlets" : 1,
"outlettype" : [ "" ],
"patching_rect" : [ 768.0, 156.0, 111.0, 22.0 ],
"text" : "load weights kieran"
}
}
, {
"box" : {
"id" : "obj-35",
"maxclass" : "message",
"numinlets" : 2,
"numoutlets" : 1,
"outlettype" : [ "" ],
"patching_rect" : [ 626.0, 156.0, 114.0, 22.0 ],
"text" : "save weights kieran"
}
}
, {
"box" : {
"data" : {
"kieran" : [ {
"data" : {
"0" : 0.927409112453461,
"1" : -0.10581861436367,
"2" : 0.110250025987625,
"3" : -1.001458287239075,
"4" : 0.689840376377106,
"5" : 0.396656483411789,
"6" : -1.090744495391846,
"7" : -0.033253565430641
}
,
"shape" : [ 2, 4 ]
}
, {
"data" : {
"0" : -0.04672335088253,
"1" : 0.441919326782227,
"2" : -0.127905443310738,
"3" : 0.395313948392868
}
,
"shape" : [ 4 ]
}
, {
"data" : {
"0" : 0.714273393154144,
"1" : 0.45893207192421,
"2" : 0.772015511989594,
"3" : -0.403766572475433,
"4" : 0.378416150808334,
"5" : 0.486913919448853,
"6" : -0.327280908823013,
"7" : -0.424903750419617,
"8" : -0.59527575969696,
"9" : -0.619619905948639,
"10" : -0.706632375717163,
"11" : -0.025489492341876
}
,
"shape" : [ 4, 3 ]
}
, {
"data" : {
"0" : -0.584373593330383,
"1" : 0.079081602394581,
"2" : 0.754254102706909
}
,
"shape" : [ 3 ]
}
]
}
,
"id" : "obj-33",
"maxclass" : "newobj",
"numinlets" : 2,
"numoutlets" : 5,
"outlettype" : [ "dictionary", "", "", "", "" ],
"patching_rect" : [ 768.0, 215.0, 133.0, 22.0 ],
"saved_object_attributes" : {
"embed" : 1,
"legacy" : 0,
"parameter_enable" : 0,
"parameter_mappable" : 0
}
,
"text" : "dict weights @embed 1"
}
}
, {
"box" : {
"id" : "obj-32",
"maxclass" : "comment",
@ -68,7 +179,7 @@
"maxclass" : "comment",
"numinlets" : 1,
"numoutlets" : 0,
"patching_rect" : [ 339.0, 44.0, 150.0, 34.0 ],
"patching_rect" : [ 339.0, 44.0, 152.0, 34.0 ],
"text" : "Uses basic SGD optimizer with 0.01 learning rate"
}
@ -150,7 +261,6 @@
"numoutlets" : 1,
"outlettype" : [ "" ],
"patching_rect" : [ 262.0, 163.0, 55.0, 22.0 ],
"presentation_linecount" : 3,
"text" : "3 3 5 6 7"
}
@ -163,7 +273,6 @@
"numoutlets" : 1,
"outlettype" : [ "" ],
"patching_rect" : [ 330.0, 163.0, 55.0, 22.0 ],
"presentation_linecount" : 3,
"text" : "4 4 7 8 9"
}
@ -176,7 +285,6 @@
"numoutlets" : 1,
"outlettype" : [ "" ],
"patching_rect" : [ 197.0, 163.0, 55.0, 22.0 ],
"presentation_linecount" : 3,
"text" : "2 2 3 4 5"
}
@ -192,19 +300,6 @@
"text" : "1 1 1 2 3"
}
}
, {
"box" : {
"id" : "obj-11",
"linecount" : 6,
"maxclass" : "message",
"numinlets" : 2,
"numoutlets" : 1,
"outlettype" : [ "" ],
"patching_rect" : [ 192.0, 413.0, 50.0, 91.0 ],
"text" : "8.914521 10.032911 11.035411"
}
}
, {
"box" : {
@ -246,7 +341,7 @@
"numoutlets" : 1,
"offset" : [ 0.0, 0.0 ],
"outlettype" : [ "bang" ],
"patching_rect" : [ 484.0, 353.0, 400.0, 220.0 ],
"patching_rect" : [ 526.0, 284.0, 400.0, 220.0 ],
"viewvisibility" : 1
}
@ -275,8 +370,6 @@
"args" : [ 2, 3, 4 ],
"autostart" : 1,
"defer" : 0,
"node_bin_path" : "",
"npm_bin_path" : "",
"watch" : 1
}
,
@ -287,7 +380,7 @@
],
"lines" : [ {
"patchline" : {
"destination" : [ "obj-11", 1 ],
"destination" : [ "obj-40", 1 ],
"source" : [ "obj-1", 0 ]
}
@ -368,6 +461,20 @@
"source" : [ "obj-20", 0 ]
}
}
, {
"patchline" : {
"destination" : [ "obj-1", 0 ],
"source" : [ "obj-35", 0 ]
}
}
, {
"patchline" : {
"destination" : [ "obj-1", 0 ],
"source" : [ "obj-36", 0 ]
}
}
, {
"patchline" : {