tensorflow-js-maxmsp/tf.js
2024-06-27 13:09:51 +02:00

50 lines
1.2 KiB
JavaScript

const tf = require('@tensorflow/tfjs');
const maxApi = require('max-api');
/**
 * Parse a positive integer dimension from the command line, failing fast with
 * a readable message instead of letting NaN reach the layer constructors.
 * @param {string} value - raw argument text
 * @param {string} name - argument name used in the error message
 * @returns {number} the parsed dimension
 * @throws {Error} when the value is missing, non-numeric, or not positive
 */
function parseDimension(value, name) {
  const parsed = Number.parseInt(value, 10);
  if (!Number.isInteger(parsed) || parsed <= 0) {
    throw new Error(`${name} must be a positive integer, got "${value}"`);
  }
  return parsed;
}

// Expected invocation: <script> <inputShape> <outputShape> <hiddenSize>
const args = process.argv.slice(2);
const inputShape = parseDimension(args[0], "inputShape");
const outputShape = parseDimension(args[1], "outputShape");
const hiddenSize = parseDimension(args[2], "hiddenSize");

// Define a model for linear regression. Neither dense layer specifies an
// activation, so the network as a whole is still a linear (affine) map.
const model = tf.sequential();
model.add(tf.layers.dense({units: hiddenSize, inputShape: [inputShape]}));
model.add(tf.layers.dense({units: outputShape}));
model.compile({loss: 'meanSquaredError', optimizer: 'sgd'});

// Training data accumulated from "dataPoint" messages: each entry of xsArr is
// one row of model inputs, and the matching entry of ysArr is its target row.
const xsArr = [];
const ysArr = [];
/**
 * "train <epochs>": fit the model on every data point collected so far.
 * The tensors built from the collected rows are released once fitting ends,
 * whether it succeeds or fails, so repeated training does not leak memory.
 */
maxApi.addHandler("train", async (epochs) => {
  if (xsArr.length === 0) {
    maxApi.post("train: no data points collected yet");
    return;
  }
  // tf.tidy() cannot wrap async work such as fit(), so tensors are disposed
  // manually in `finally`.
  let xs;
  let ys;
  try {
    // Aggregate the collected rows into [rows, features] matrices. Construction
    // sits inside the try because ragged rows make tensor2d throw.
    xs = tf.tensor2d(xsArr, [xsArr.length, inputShape]);
    ys = tf.tensor2d(ysArr, [ysArr.length, outputShape]);
    // Awaiting surfaces errors such as a second fit() started while one is
    // still running, instead of leaving an unhandled rejection.
    await model.fit(xs, ys, {epochs});
  } catch (err) {
    maxApi.post(`train failed: ${err.message}`);
  } finally {
    if (xs) xs.dispose();
    if (ys) ys.dispose();
  }
});
/**
 * "dataPoint x1 ... xN y1 ... yM": store one training example. The first
 * `inputShape` values are the inputs and the remaining `outputShape` values
 * are the targets. Malformed messages are reported and dropped so they cannot
 * break a later "train".
 */
maxApi.addHandler("dataPoint", (...data) => {
  // Keep the converted values: Array#map returns a new array and does not
  // modify `data` in place.
  const values = data.map((item) => Number.parseFloat(item));
  const expected = inputShape + outputShape;
  if (values.length !== expected || values.some(Number.isNaN)) {
    maxApi.post(`dataPoint: expected ${expected} numeric values, got [${data.join(" ")}]`);
    return;
  }
  xsArr.push(values.slice(0, inputShape));
  ysArr.push(values.slice(inputShape));
});
/**
 * "predict x1 ... xN": run the model on one input row and send the predicted
 * output row to the outlet as a list. Invalid input is reported and ignored.
 */
maxApi.addHandler("predict", async (...data) => {
  // Keep the converted values; Array#map does not modify `data` in place.
  const values = data.map((item) => Number.parseFloat(item));
  if (values.length !== inputShape || values.some(Number.isNaN)) {
    maxApi.post(`predict: expected ${inputShape} numeric values, got [${data.join(" ")}]`);
    return;
  }
  let prediction;
  try {
    // tidy() disposes the temporary input tensor and keeps only the returned
    // prediction, which is released in `finally` after it has been read.
    prediction = tf.tidy(() => model.predict(tf.tensor2d([values], [1, inputShape])));
    const rows = await prediction.array();
    await maxApi.outlet(rows[0]);
  } catch (err) {
    maxApi.post(`predict failed: ${err.message}`);
  } finally {
    if (prediction) prediction.dispose();
  }
});