loading configures saved architecture

This commit is contained in:
trian-gles 2024-06-28 09:11:01 +02:00
parent 9e14de36fb
commit ea17b45f1c

38
tf.js
View File

@ -7,16 +7,15 @@ const inputShape = parseInt(args[0]);
const outputShape = parseInt(args[1]); const outputShape = parseInt(args[1]);
const hiddenSize = parseInt(args[2]); const hiddenSize = parseInt(args[2]);
// Define a model for linear regression.
var model = tf.sequential(); var model = tf.sequential();
model.add(tf.layers.dense({units: hiddenSize, inputShape: [inputShape]})); model.add(tf.layers.dense({units: hiddenSize, inputShape: [inputShape]}));
model.add(tf.layers.dense({units: outputShape})); model.add(tf.layers.dense({units: outputShape}));
model.compile({loss: 'meanSquaredError', optimizer: 'sgd'});
// Generate some synthetic data for training.
var xsArr = []; var xsArr = [];
@ -24,6 +23,7 @@ var ysArr = [];
maxApi.addHandler("train", (epochs) => { maxApi.addHandler("train", (epochs) => {
model.compile({loss: 'meanSquaredError', optimizer: 'sgd'});
// aggregate data // aggregate data
const xs = tf.tensor2d(xsArr, [xsArr.length, inputShape]); const xs = tf.tensor2d(xsArr, [xsArr.length, inputShape]);
@ -61,26 +61,34 @@ async function getWeights() {
maxApi.addHandler("save", (dictId, key) => { maxApi.addHandler("save", (dictId, key) => {
maxApi.getDict(dictId).then((dict) => { maxApi.getDict(dictId).then((dict) => {
getWeights().then((weights) => { getWeights().then((weights) => {
dict[key] = weights; dict[key] = {}
dict[key].weights = weights;
dict[key].model = model.toJSON(null, false);
maxApi.setDict(dictId, dict); maxApi.setDict(dictId, dict);
}); });
}); });
}); });
maxApi.addHandler("load", (dictId, key) => { maxApi.addHandler("load", (dictId, key) => {
maxApi.getDict(dictId).then((dict) => { maxApi.getDict(dictId).then((dict) => {
let data = dict[key]; let arch = dict[key].model;
let tensors = []; tf.models.modelFromJSON(dict[key].model).then((m) => {
data.forEach(item => { model = m
let shape = item.shape;
let vals = [];
for (const [key, value] of Object.entries(item.data)) {
vals.push(value);
} let data = dict[key].weights;
tensors.push(tf.tensor(vals, shape)); let tensors = [];
}); data.forEach(item => {
model.setWeights(tensors); let shape = item.shape;
let vals = [];
for (const [key, value] of Object.entries(item.data)) {
vals.push(value);
}
tensors.push(tf.tensor(vals, shape));
});
model.setWeights(tensors);
});
}); });
}); });