loading configures saved architecture
This commit is contained in:
parent
9e14de36fb
commit
ea17b45f1c
18
tf.js
18
tf.js
@ -7,16 +7,15 @@ const inputShape = parseInt(args[0]);
|
||||
const outputShape = parseInt(args[1]);
|
||||
const hiddenSize = parseInt(args[2]);
|
||||
|
||||
// Define a model for linear regression.
|
||||
|
||||
var model = tf.sequential();
|
||||
|
||||
|
||||
|
||||
model.add(tf.layers.dense({units: hiddenSize, inputShape: [inputShape]}));
|
||||
model.add(tf.layers.dense({units: outputShape}));
|
||||
model.compile({loss: 'meanSquaredError', optimizer: 'sgd'});
|
||||
|
||||
// Generate some synthetic data for training.
|
||||
|
||||
|
||||
|
||||
var xsArr = [];
|
||||
@ -24,6 +23,7 @@ var ysArr = [];
|
||||
|
||||
|
||||
maxApi.addHandler("train", (epochs) => {
|
||||
model.compile({loss: 'meanSquaredError', optimizer: 'sgd'});
|
||||
|
||||
// aggregate data
|
||||
const xs = tf.tensor2d(xsArr, [xsArr.length, inputShape]);
|
||||
@ -61,15 +61,22 @@ async function getWeights() {
|
||||
maxApi.addHandler("save", (dictId, key) => {
|
||||
maxApi.getDict(dictId).then((dict) => {
|
||||
getWeights().then((weights) => {
|
||||
dict[key] = weights;
|
||||
dict[key] = {}
|
||||
dict[key].weights = weights;
|
||||
dict[key].model = model.toJSON(null, false);
|
||||
maxApi.setDict(dictId, dict);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
|
||||
maxApi.addHandler("load", (dictId, key) => {
|
||||
maxApi.getDict(dictId).then((dict) => {
|
||||
let data = dict[key];
|
||||
let arch = dict[key].model;
|
||||
tf.models.modelFromJSON(dict[key].model).then((m) => {
|
||||
model = m
|
||||
|
||||
let data = dict[key].weights;
|
||||
let tensors = [];
|
||||
data.forEach(item => {
|
||||
let shape = item.shape;
|
||||
@ -83,4 +90,5 @@ maxApi.addHandler("load", (dictId, key) => {
|
||||
model.setWeights(tensors);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user