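// Max [node.script] bridge between max-api and TensorFlow.js: builds a small
// two-layer dense network sized by the script arguments, collects training
// pairs from Max messages, trains on demand, and saves/restores the model
// through a Max dict.
//
// Illustrative launch (script name assumed): node.script nn.js 4 2 16
// -> 4 inputs, 2 outputs, 16 hidden units.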
const tf = require('@tensorflow/tfjs');
const maxApi = require('max-api');

// Network dimensions are passed as script arguments.
const args = process.argv.slice(2);
const inputShape = parseInt(args[0]);
const outputShape = parseInt(args[1]);
const hiddenSize = parseInt(args[2]);

// A two-layer perceptron: one ReLU hidden layer, one linear output layer.
// "let" rather than "var": the model is reassigned when weights are loaded.
let model = tf.sequential();
model.add(tf.layers.dense({units: hiddenSize, inputShape: [inputShape], activation: 'relu'}));
model.add(tf.layers.dense({units: outputShape}));

// Training pairs accumulated from "data_point" messages.
let xsArr = [];
let ysArr = [];

maxApi.addHandler("train", (epochs) => {
|
|
model.compile({loss: 'meanSquaredError', optimizer: 'sgd'});
|
|
|
|
// aggregate data
|
|
const xs = tf.tensor2d(xsArr, [xsArr.length, inputShape]);
|
|
const ys = tf.tensor2d(ysArr, [xsArr.length, outputShape]);
|
|
|
|
// Train the model using the data.
|
|
model.fit(xs, ys, {epochs});
|
|
|
|
maxApi.outlet("training_done");
|
|
});
|
|
|
|
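// Each "data_point" message carries inputShape input floats followed by
// outputShape target floats; e.g. with inputShape 2 and outputShape 1,
// "data_point 0.3 0.7 1." stores the pair [0.3, 0.7] -> [1.0] (illustrative).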
maxApi.addHandler("data_point", (...data) => {
|
|
data.map((item) => parseFloat(item));
|
|
xsArr.push(data.slice(0, inputShape));
|
|
ysArr.push(data.slice(inputShape));
|
|
});
|
|
|
|
maxApi.addHandler("clear_data", () => {
|
|
xsArr = [];
|
|
ysArr = [];
|
|
});
|
|
|
|
maxApi.addHandler("dump_data", () => {
|
|
for (let i=0; i<xsArr.length; i++) {
|
|
maxApi.outlet(xsArr[i].concat(ysArr[i]));
|
|
}
|
|
});
|
|
|
|
|
|
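// "predict" takes inputShape floats and replies with the network's
// outputShape outputs, e.g. "predict 0.3 0.7" (illustrative values).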
maxApi.addHandler("predict", (...data) => {
|
|
data.map((item) => parseFloat(item));
|
|
model.predict(tf.tensor2d([data], [1, inputShape])).array().then((value) => {
|
|
maxApi.outlet(value[0]);
|
|
});
|
|
|
|
});
|
|
|
|
// Extract the model's weight tensors as plain {data, shape} records so they
// can be stored in a Max dict.
async function getWeights() {
  let weights = model.getWeights();
  for (let i = 0; i < weights.length; i++) {
    let data = await weights[i].data();
    let shape = weights[i].shape;
    weights[i] = {data, shape};
  }
  return weights;
}

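// "save <dictId> <key>" snapshots the model into the named Max dict under
// the given key, e.g. "save models net1" (dict and key names illustrative).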
maxApi.addHandler("save", (dictId, key) => {
|
|
maxApi.getDict(dictId).then((dict) => {
|
|
getJson().then((json) => {
|
|
dict[key] = json
|
|
maxApi.setDict(dictId, dict);
|
|
});
|
|
});
|
|
});
|
|
|
|
// Serialize the model topology plus weights into one JSON-friendly object.
async function getJson() {
  let json = {};
  json.weights = await getWeights();
  json.model = model.toJSON(null, false);
  return json;
}

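// Illustrative shape of a saved entry, as produced by getJson():
//   { model: <topology JSON>,
//     weights: [{data: {0: w0, 1: w1, ...}, shape: [rows, cols]}, ...] }
// After the round trip through the dict, each flat weight buffer presumably
// comes back as an index-keyed object rather than a typed array, hence the
// Object.values() rebuild below.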
// Rebuild the model from a saved dict: restore the topology first, then the
// weight tensors.
function loadWeights(dict) {
  tf.models.modelFromJSON(dict.model).then((m) => {
    model = m;
    let tensors = [];
    dict.weights.forEach((item) => {
      let vals = Object.values(item.data);
      tensors.push(tf.tensor(vals, item.shape));
    });
    model.setWeights(tensors);
  });
}

maxApi.addHandler("set_weights", loadWeights);
|
|
|
|
maxApi.addHandler("dump_weights", () => {
|
|
getJson().then((json) => {maxApi.outlet("weights", json);});
|
|
})
|
|
|