Skip to content

Latest commit

 

History

History
128 lines (92 loc) · 4.04 KB

File metadata and controls

128 lines (92 loc) · 4.04 KB

TensorflowJS in React Native

Volledige code voor TensorflowJS in React Native.

  • Inlezen CSV
  • Trainen van een model
  • Doe een voorspelling

COMPONENT

Het component laadt de CSV en roept de training functie aan. Je kan de predict functie aanroepen met een button.

import * as React from 'react';
import { Text, View, Button, StyleSheet, TextInput } from 'react-native';
import { readRemoteFile, readString } from 'react-native-csv';
import * as tf from '@tensorflow/tfjs'
import { createNeuralNetwork } from "./createNeuralNetwork.js"
import carsData from '../assets/cars.csv';

export default function TFExample() {
  const [prediction, setPrediction] = React.useState(0);
  const [hp, onChangeHP] = React.useState(50);
  const [weight, onChangeWeight] = React.useState(2000);
  const machine = React.useRef();
  const normalizeValues = React.useRef()


  //
  // prediction
  //
  const makePrediction = () => {
      // let op dat hp en weight numbers zijn en geen strings
      const userInput = tf.tensor2d([[hp,weight]])
      // normalize the user input, un-normalise the prediction
      const nc = normalizeValues.current
      const normInput = userInput.sub(nc.inputMin).div(nc.inputMax.sub(nc.inputMin))
      const predictionTF = machine.current.predict(normInput)
      const unnormalisedPredictionTF = predictionTF.mul(nc.labelMax.sub(nc.labelMin)).add(nc.labelMin)
      const prediction = Math.round(unnormalisedPredictionTF.dataSync()[0])

      setPrediction(prediction)
  }

  //
  // load csv
  //
  React.useEffect(() => {
    const loadCSV = () => {
      readRemoteFile(carsData, {
        download: true,
        header: true,
        dynamicTyping: true,
        complete: (results) => {
          csvLoaded(results.data);
        },
      });
    };

    const csvLoaded = async(data) => {
        data.sort(() => (Math.random() - 0.5))

        const inputs = data.map(d => [d.horsepower, d.weight])
        const outputs = data.map(d => d.mpg)
        const [model, normValues] = await createNeuralNetwork(inputs, outputs)
        
        machine.current = model
        normalizeValues.current = normValues
    };

    loadCSV();
  }, []);
}




Het Neural Network bouwen

Voor het overzicht is het bouwen van het neural network en het normalizen van de data in een eigen JS bestand geplaatst. Deze moet je importeren in het component.

import * as tf from '@tensorflow/tfjs'

export const createNeuralNetwork = async (inputs, outputs) => {
    // tensors maken van de inputs en outputs, let op dat alle data NUMBERS zijn!
    const inputTensor = tf.tensor2d(inputs)
    const labelTensor = tf.tensor1d(outputs)

    const [inputMax, inputMin, labelMax, labelMin] = [inputTensor.max(0,false), inputTensor.min(0,false), labelTensor.max(), labelTensor.min()]
    const normalizedInputs = inputTensor.sub(inputMin).div(inputMax.sub(inputMin))
    const normalizedLabels = labelTensor.sub(labelMin).div(labelMax.sub(labelMin))

    // bouw het model met 2 features: horsepower, weight
    const numFeatures = inputs[0].length
    const model = tf.sequential()
    model.add(tf.layers.dense({ units: 8, inputShape: [numFeatures] }))
    model.add(tf.layers.dense({ units: 1 }))
    model.compile({ loss: 'meanSquaredError', optimizer: 'sgd' })
    // aantal epochs instellen
    await model.fit(normalizedInputs, normalizedLabels, { epochs: 5 })
    return [model, { inputMin, inputMax, labelMin, labelMax }]    
}




Training visualisatie

Je kan de visualisatietool van tensorflowJS niet gebruiken in React Native omdat de tool verwacht dat er een browser venster is. Als het niet nodig is om live te trainen met gebruikersinput zou je het model al vantevoren kunnen maken in de browser.




Model opslaan

Je kan het model opslaan in de AsyncStorage van React Native. De tfjs-react-native library heeft hier een helper functie voor: