From ea17b45f1c7186f123edc94b13a06d6d9e6acc7f Mon Sep 17 00:00:00 2001 From: trian-gles <69212477+trian-gles@users.noreply.github.com> Date: Fri, 28 Jun 2024 09:11:01 +0200 Subject: [PATCH] loading configures saved architecture --- tf.js | 38 +++++++++++++++++++++++--------------- 1 file changed, 23 insertions(+), 15 deletions(-) diff --git a/tf.js b/tf.js index 1665451..9d60fcc 100644 --- a/tf.js +++ b/tf.js @@ -7,16 +7,15 @@ const inputShape = parseInt(args[0]); const outputShape = parseInt(args[1]); const hiddenSize = parseInt(args[2]); -// Define a model for linear regression. + var model = tf.sequential(); model.add(tf.layers.dense({units: hiddenSize, inputShape: [inputShape]})); model.add(tf.layers.dense({units: outputShape})); -model.compile({loss: 'meanSquaredError', optimizer: 'sgd'}); -// Generate some synthetic data for training. + var xsArr = []; @@ -24,6 +23,7 @@ var ysArr = []; maxApi.addHandler("train", (epochs) => { + model.compile({loss: 'meanSquaredError', optimizer: 'sgd'}); // aggregate data const xs = tf.tensor2d(xsArr, [xsArr.length, inputShape]); @@ -61,26 +61,34 @@ async function getWeights() { maxApi.addHandler("save", (dictId, key) => { maxApi.getDict(dictId).then((dict) => { getWeights().then((weights) => { - dict[key] = weights; + dict[key] = {} + dict[key].weights = weights; + dict[key].model = model.toJSON(null, false); maxApi.setDict(dictId, dict); }); }); }); + maxApi.addHandler("load", (dictId, key) => { maxApi.getDict(dictId).then((dict) => { - let data = dict[key]; - let tensors = []; - data.forEach(item => { - let shape = item.shape; - let vals = []; - for (const [key, value] of Object.entries(item.data)) { - vals.push(value); + let arch = dict[key].model; + tf.models.modelFromJSON(dict[key].model).then((m) => { + model = m - } - tensors.push(tf.tensor(vals, shape)); - }); - model.setWeights(tensors); + let data = dict[key].weights; + let tensors = []; + data.forEach(item => { + let shape = item.shape; + let vals = []; + for (const [key, value] of Object.entries(item.data)) { + vals.push(value); + + } + tensors.push(tf.tensor(vals, shape)); + }); + model.setWeights(tensors); + }); }); });