loading configures saved architecture

This commit is contained in:
trian-gles 2024-06-28 09:11:01 +02:00
parent 9e14de36fb
commit ea17b45f1c

18
tf.js
View File

@ -7,16 +7,15 @@ const inputShape = parseInt(args[0]);
const outputShape = parseInt(args[1]); const outputShape = parseInt(args[1]);
const hiddenSize = parseInt(args[2]); const hiddenSize = parseInt(args[2]);
// Define a model for linear regression.
var model = tf.sequential(); var model = tf.sequential();
model.add(tf.layers.dense({units: hiddenSize, inputShape: [inputShape]})); model.add(tf.layers.dense({units: hiddenSize, inputShape: [inputShape]}));
model.add(tf.layers.dense({units: outputShape})); model.add(tf.layers.dense({units: outputShape}));
model.compile({loss: 'meanSquaredError', optimizer: 'sgd'});
// Generate some synthetic data for training.
var xsArr = []; var xsArr = [];
@ -24,6 +23,7 @@ var ysArr = [];
maxApi.addHandler("train", (epochs) => { maxApi.addHandler("train", (epochs) => {
model.compile({loss: 'meanSquaredError', optimizer: 'sgd'});
// aggregate data // aggregate data
const xs = tf.tensor2d(xsArr, [xsArr.length, inputShape]); const xs = tf.tensor2d(xsArr, [xsArr.length, inputShape]);
@ -61,15 +61,22 @@ async function getWeights() {
maxApi.addHandler("save", (dictId, key) => { maxApi.addHandler("save", (dictId, key) => {
maxApi.getDict(dictId).then((dict) => { maxApi.getDict(dictId).then((dict) => {
getWeights().then((weights) => { getWeights().then((weights) => {
dict[key] = weights; dict[key] = {}
dict[key].weights = weights;
dict[key].model = model.toJSON(null, false);
maxApi.setDict(dictId, dict); maxApi.setDict(dictId, dict);
}); });
}); });
}); });
maxApi.addHandler("load", (dictId, key) => { maxApi.addHandler("load", (dictId, key) => {
maxApi.getDict(dictId).then((dict) => { maxApi.getDict(dictId).then((dict) => {
let data = dict[key]; let arch = dict[key].model;
tf.models.modelFromJSON(dict[key].model).then((m) => {
model = m
let data = dict[key].weights;
let tensors = []; let tensors = [];
data.forEach(item => { data.forEach(item => {
let shape = item.shape; let shape = item.shape;
@ -82,5 +89,6 @@ maxApi.addHandler("load", (dictId, key) => {
}); });
model.setWeights(tensors); model.setWeights(tensors);
}); });
});
}); });