loading configures saved architecture
This commit is contained in:
parent
9e14de36fb
commit
ea17b45f1c
18
tf.js
18
tf.js
@ -7,16 +7,15 @@ const inputShape = parseInt(args[0]);
|
|||||||
const outputShape = parseInt(args[1]);
|
const outputShape = parseInt(args[1]);
|
||||||
const hiddenSize = parseInt(args[2]);
|
const hiddenSize = parseInt(args[2]);
|
||||||
|
|
||||||
// Define a model for linear regression.
|
|
||||||
var model = tf.sequential();
|
var model = tf.sequential();
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
model.add(tf.layers.dense({units: hiddenSize, inputShape: [inputShape]}));
|
model.add(tf.layers.dense({units: hiddenSize, inputShape: [inputShape]}));
|
||||||
model.add(tf.layers.dense({units: outputShape}));
|
model.add(tf.layers.dense({units: outputShape}));
|
||||||
model.compile({loss: 'meanSquaredError', optimizer: 'sgd'});
|
|
||||||
|
|
||||||
// Generate some synthetic data for training.
|
|
||||||
|
|
||||||
|
|
||||||
var xsArr = [];
|
var xsArr = [];
|
||||||
@ -24,6 +23,7 @@ var ysArr = [];
|
|||||||
|
|
||||||
|
|
||||||
maxApi.addHandler("train", (epochs) => {
|
maxApi.addHandler("train", (epochs) => {
|
||||||
|
model.compile({loss: 'meanSquaredError', optimizer: 'sgd'});
|
||||||
|
|
||||||
// aggregate data
|
// aggregate data
|
||||||
const xs = tf.tensor2d(xsArr, [xsArr.length, inputShape]);
|
const xs = tf.tensor2d(xsArr, [xsArr.length, inputShape]);
|
||||||
@ -61,15 +61,22 @@ async function getWeights() {
|
|||||||
maxApi.addHandler("save", (dictId, key) => {
|
maxApi.addHandler("save", (dictId, key) => {
|
||||||
maxApi.getDict(dictId).then((dict) => {
|
maxApi.getDict(dictId).then((dict) => {
|
||||||
getWeights().then((weights) => {
|
getWeights().then((weights) => {
|
||||||
dict[key] = weights;
|
dict[key] = {}
|
||||||
|
dict[key].weights = weights;
|
||||||
|
dict[key].model = model.toJSON(null, false);
|
||||||
maxApi.setDict(dictId, dict);
|
maxApi.setDict(dictId, dict);
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
||||||
maxApi.addHandler("load", (dictId, key) => {
|
maxApi.addHandler("load", (dictId, key) => {
|
||||||
maxApi.getDict(dictId).then((dict) => {
|
maxApi.getDict(dictId).then((dict) => {
|
||||||
let data = dict[key];
|
let arch = dict[key].model;
|
||||||
|
tf.models.modelFromJSON(dict[key].model).then((m) => {
|
||||||
|
model = m
|
||||||
|
|
||||||
|
let data = dict[key].weights;
|
||||||
let tensors = [];
|
let tensors = [];
|
||||||
data.forEach(item => {
|
data.forEach(item => {
|
||||||
let shape = item.shape;
|
let shape = item.shape;
|
||||||
@ -82,5 +89,6 @@ maxApi.addHandler("load", (dictId, key) => {
|
|||||||
});
|
});
|
||||||
model.setWeights(tensors);
|
model.setWeights(tensors);
|
||||||
});
|
});
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user