loading configures saved architecture

This commit is contained in:
trian-gles 2024-06-28 09:11:01 +02:00
parent 9e14de36fb
commit ea17b45f1c

18
tf.js
View File

@ -7,16 +7,15 @@ const inputShape = parseInt(args[0]);
const outputShape = parseInt(args[1]);
const hiddenSize = parseInt(args[2]);
// Define a model for linear regression.
var model = tf.sequential();
model.add(tf.layers.dense({units: hiddenSize, inputShape: [inputShape]}));
model.add(tf.layers.dense({units: outputShape}));
model.compile({loss: 'meanSquaredError', optimizer: 'sgd'});
// Generate some synthetic data for training.
var xsArr = [];
@ -24,6 +23,7 @@ var ysArr = [];
maxApi.addHandler("train", (epochs) => {
model.compile({loss: 'meanSquaredError', optimizer: 'sgd'});
// aggregate data
const xs = tf.tensor2d(xsArr, [xsArr.length, inputShape]);
@ -61,15 +61,22 @@ async function getWeights() {
maxApi.addHandler("save", (dictId, key) => {
maxApi.getDict(dictId).then((dict) => {
getWeights().then((weights) => {
dict[key] = weights;
dict[key] = {}
dict[key].weights = weights;
dict[key].model = model.toJSON(null, false);
maxApi.setDict(dictId, dict);
});
});
});
maxApi.addHandler("load", (dictId, key) => {
maxApi.getDict(dictId).then((dict) => {
let data = dict[key];
let arch = dict[key].model;
tf.models.modelFromJSON(dict[key].model).then((m) => {
model = m
let data = dict[key].weights;
let tensors = [];
data.forEach(item => {
let shape = item.shape;
@ -83,4 +90,5 @@ maxApi.addHandler("load", (dictId, key) => {
model.setWeights(tensors);
});
});
});