AI/ML March 5, 2024 10 min read

Introduction to Machine Learning with JavaScript

Explore the fundamentals of machine learning using JavaScript and TensorFlow.js. Build your first neural network and understand core ML concepts.

Nikesh Bhattarai
Backend Developer & AI/ML Engineer

Introduction

Machine learning in JavaScript has become increasingly accessible thanks to libraries like TensorFlow.js. This guide walks you through the fundamentals of ML and helps you build your first neural network entirely in JavaScript. It is written for web developers who want to add ML capabilities to their applications, and for anyone curious about how ML works.

We'll cover basic ML concepts, build and train a small neural network, classify images with a pre-trained model, and detect objects from a webcam feed, all using JavaScript. By the end of this guide, you'll understand how ML works in both browser and Node.js environments.

What is Machine Learning?

Machine Learning is a subset of artificial intelligence that enables computers to learn and improve from experience without being explicitly programmed. Instead of writing rules to solve a problem, you provide data and let the algorithm learn the patterns.

// Traditional Programming
function predictPrice(size, location, bedrooms) {
  // Hard-coded rules
  if (location === 'downtown') {
    return size * 200 + bedrooms * 50000;
  } else {
    return size * 150 + bedrooms * 30000;
  }
}

// Machine Learning Approach
// (pseudocode: trainModel and trainingData are placeholders, not a real API)
const model = await trainModel(trainingData);
const prediction = await model.predict([size, location, bedrooms]);
// Model learns patterns from data automatically
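
To make the contrast concrete, here is a minimal sketch of "learning from data" in plain JavaScript. It fits price = w * size + b to a handful of made-up listings with ordinary least squares. The listings and the `fitLine` helper are assumptions chosen for illustration, not part of any library.

```javascript
// Hypothetical listings: size in square feet, price in dollars.
// The prices follow price = 150 * size + 30000, so the fit is easy to check.
const listings = [
  { size: 800, price: 150000 },
  { size: 1000, price: 180000 },
  { size: 1200, price: 210000 },
  { size: 1500, price: 255000 }
];

// Ordinary least squares for a single feature: returns slope w and intercept b.
function fitLine(points) {
  const n = points.length;
  const meanX = points.reduce((sum, p) => sum + p.size, 0) / n;
  const meanY = points.reduce((sum, p) => sum + p.price, 0) / n;
  let covXY = 0;
  let varX = 0;
  for (const p of points) {
    covXY += (p.size - meanX) * (p.price - meanY);
    varX += (p.size - meanX) ** 2;
  }
  const w = covXY / varX;
  const b = meanY - w * meanX;
  return { w, b };
}

// The "rules" (w and b) come from the data instead of being hard-coded.
const learned = fitLine(listings);
const predictPriceLearned = (size) => learned.w * size + learned.b;
```

Nobody typed the numbers 150 and 30000 into the prediction function; they were recovered from the examples. A neural network does the same thing with many more parameters.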

Types of Machine Learning

Supervised Learning

Learning from labeled data where the correct answers are provided:

// Example: Email classification
const trainingData = [
  { text: "Buy now! Limited offer!", label: "spam" },
  { text: "Meeting tomorrow at 3pm", label: "not_spam" },
  { text: "Congratulations! You won!", label: "spam" },
  { text: "Project update attached", label: "not_spam" }
];

// Model learns to classify emails based on examples
const model = await trainClassifier(trainingData);
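
As a rough idea of what a helper like `trainClassifier` could do internally, the sketch below trains a tiny naive Bayes text classifier on the same four examples. The function names `tokenize`, `trainNaiveBayes`, and `classify` are hypothetical, and four examples are far too few for a real spam filter.

```javascript
// Labeled examples, matching the snippet above.
const examples = [
  { text: "Buy now! Limited offer!", label: "spam" },
  { text: "Meeting tomorrow at 3pm", label: "not_spam" },
  { text: "Congratulations! You won!", label: "spam" },
  { text: "Project update attached", label: "not_spam" }
];

// Lowercase the text and split it into word tokens.
const tokenize = (text) => text.toLowerCase().match(/[a-z0-9]+/g) || [];

// Count how often each word appears under each label.
function trainNaiveBayes(data) {
  const wordCounts = {}; // label -> { word -> count }
  const totalWords = {}; // label -> number of tokens seen
  const docCounts = {};  // label -> number of examples
  const vocabulary = new Set();
  for (const { text, label } of data) {
    wordCounts[label] = wordCounts[label] || {};
    totalWords[label] = totalWords[label] || 0;
    docCounts[label] = (docCounts[label] || 0) + 1;
    for (const word of tokenize(text)) {
      wordCounts[label][word] = (wordCounts[label][word] || 0) + 1;
      totalWords[label] += 1;
      vocabulary.add(word);
    }
  }
  return { wordCounts, totalWords, docCounts, vocabulary, numDocs: data.length };
}

// Return the label with the highest log-probability.
// Add-one (Laplace) smoothing keeps unseen words from zeroing out a label.
function classify(model, text) {
  let best = null;
  for (const label of Object.keys(model.docCounts)) {
    let score = Math.log(model.docCounts[label] / model.numDocs);
    for (const word of tokenize(text)) {
      const count = model.wordCounts[label][word] || 0;
      score += Math.log((count + 1) / (model.totalWords[label] + model.vocabulary.size));
    }
    if (best === null || score > best.score) best = { label, score };
  }
  return best.label;
}

const spamModel = trainNaiveBayes(examples);
```

Calling `classify(spamModel, "Limited offer, buy now")` returns "spam" because every word in that message was only ever seen in spam examples.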

Unsupervised Learning

Finding patterns in unlabeled data:

// Example: Customer segmentation
const customerData = [
  { age: 25, income: 50000, spending: 2000 },
  { age: 45, income: 80000, spending: 5000 },
  { age: 30, income: 60000, spending: 3000 }
];

// Model groups similar customers together
const clusters = await findClusters(customerData);
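
One common way to implement something like `findClusters` is k-means. The sketch below is a minimal version in plain JavaScript. The two extra customers, the rescaling of income and spending to thousands, and the choice of the first k points as starting centroids are all assumptions made to keep the example small and deterministic.

```javascript
// The three customers from above plus two hypothetical ones.
const customers = [
  { age: 25, income: 50000, spending: 2000 },
  { age: 45, income: 80000, spending: 5000 },
  { age: 30, income: 60000, spending: 3000 },
  { age: 50, income: 90000, spending: 6000 }, // hypothetical
  { age: 23, income: 45000, spending: 1500 }  // hypothetical
];

// Express income and spending in thousands so no single feature dominates the distance.
const points = customers.map((c) => [c.age, c.income / 1000, c.spending / 1000]);

function squaredDistance(a, b) {
  return a.reduce((sum, value, i) => sum + (value - b[i]) ** 2, 0);
}

function kMeans(data, k, maxIterations = 20) {
  // Start from the first k points (simple and deterministic, not always a good choice).
  let centroids = data.slice(0, k).map((p) => [...p]);
  let assignments = new Array(data.length).fill(-1);
  for (let iter = 0; iter < maxIterations; iter++) {
    // Assignment step: attach each point to its nearest centroid.
    const next = data.map((p) => {
      let best = 0;
      for (let c = 1; c < k; c++) {
        if (squaredDistance(p, centroids[c]) < squaredDistance(p, centroids[best])) best = c;
      }
      return best;
    });
    const changed = next.some((a, i) => a !== assignments[i]);
    assignments = next;
    if (!changed) break; // converged
    // Update step: move each centroid to the mean of its members.
    centroids = centroids.map((old, c) => {
      const members = data.filter((_, i) => assignments[i] === c);
      if (members.length === 0) return old;
      return old.map((_, d) => members.reduce((s, m) => s + m[d], 0) / members.length);
    });
  }
  return { assignments, centroids };
}

const segments = kMeans(points, 2);
```

Here `segments.assignments` is `[0, 1, 0, 1, 0]`: the younger, lower-income customers end up in one group and the older, higher-income customers in the other, without any labels being provided.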

Reinforcement Learning

Learning through trial and error with rewards and penalties:

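// Example: Game playing AI
// (pseudocode: GameAgent and game are placeholder objects)
const agent = new GameAgent();
let score = 0;
let state = game.getState();

while (!game.isOver()) {
  const action = agent.chooseAction(state);
  const reward = game.makeMove(action);
  const newState = game.getState();
  agent.learn(state, action, reward, newState);
  state = newState;
  score += reward;
}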
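
For a version of this loop that actually runs, here is a minimal sketch of tabular Q-learning on a made-up environment: a five-cell corridor where the agent starts in cell 0 and earns a reward of 1 for reaching cell 4. The environment, the hyperparameters, and all function names are assumptions for illustration.

```javascript
const NUM_STATES = 5;     // corridor cells 0..4
const GOAL = 4;           // reaching cell 4 ends the episode with reward 1
const ACTIONS = [-1, 1];  // index 0 = move left, index 1 = move right

// Environment: apply an action, clamp to the corridor, report the reward.
function step(state, actionIndex) {
  const next = Math.min(Math.max(state + ACTIONS[actionIndex], 0), NUM_STATES - 1);
  return { next, reward: next === GOAL ? 1 : 0, done: next === GOAL };
}

function trainQTable(episodes = 500, alpha = 0.5, gamma = 0.9, epsilon = 0.2) {
  // q[state][action] estimates the discounted future reward of taking that action.
  const q = Array.from({ length: NUM_STATES }, () => [0, 0]);
  for (let e = 0; e < episodes; e++) {
    let state = 0;
    let done = false;
    while (!done) {
      // Explore with probability epsilon, or when both actions look equally good.
      const explore = Math.random() < epsilon || q[state][0] === q[state][1];
      const action = explore
        ? Math.floor(Math.random() * 2)
        : (q[state][1] > q[state][0] ? 1 : 0);
      const result = step(state, action);
      // Q-learning update: move the estimate toward reward + discounted best next value.
      const target = result.reward + (result.done ? 0 : gamma * Math.max(...q[result.next]));
      q[state][action] += alpha * (target - q[state][action]);
      state = result.next;
      done = result.done;
    }
  }
  return q;
}

const qTable = trainQTable();
// Greedy policy for the non-terminal cells.
const policy = qTable.slice(0, GOAL).map(([left, right]) => (right > left ? 'right' : 'left'));
```

After training, `policy` is "right" for every cell: the agent was never told which direction is correct, it only observed which actions eventually led to a reward.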

Getting Started with TensorFlow.js

TensorFlow.js is Google's open-source library for machine learning in JavaScript. Let's set up a basic project:

// Install TensorFlow.js (run this command in a terminal, not in JavaScript)
npm install @tensorflow/tfjs

// Import in your project
import * as tf from '@tensorflow/tfjs';

// Check if TensorFlow.js is loaded
console.log('TensorFlow.js version:', tf.version.tfjs);

// Basic tensor operations
const tensor1 = tf.tensor([1, 2, 3, 4]);
const tensor2 = tf.tensor([5, 6, 7, 8]);

const sum = tensor1.add(tensor2);
sum.print(); // [6, 8, 10, 12]

Building Your First Neural Network

Let's build a simple neural network for multi-class classification (three flower species, matching the data prepared in the next section):

// Create a sequential model
const model = tf.sequential();

// Add layers to the model
model.add(tf.layers.dense({
  units: 32,
  activation: 'relu',
  inputShape: [4] // 4 input features
}));

model.add(tf.layers.dense({
  units: 16,
  activation: 'relu'
}));

model.add(tf.layers.dense({
  units: 3, // One output per class
  activation: 'softmax' // Multi-class classification
}));

// Compile the model
model.compile({
  optimizer: 'adam',
  loss: 'categoricalCrossentropy', // Labels are one-hot encoded
  metrics: ['accuracy']
});

// Display model summary
model.summary();
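
To demystify what a dense layer computes and where the parameter counts in `model.summary()` come from, here is a small plain-JavaScript sketch. The helpers `denseParamCount` and `denseForward` are hypothetical and use hand-picked weights; TensorFlow.js does the same arithmetic on tensors with learned weights.

```javascript
// Trainable parameters in a dense layer:
// one weight per (input, unit) pair plus one bias per unit.
function denseParamCount(inputSize, units) {
  return inputSize * units + units;
}

const relu = (x) => Math.max(0, x);
const sigmoid = (x) => 1 / (1 + Math.exp(-x));

// Forward pass of one dense layer:
// output[j] = activation(bias[j] + sum over i of input[i] * weights[i][j])
function denseForward(input, weights, biases, activation) {
  return biases.map((bias, j) => {
    const weightedSum = input.reduce((sum, x, i) => sum + x * weights[i][j], bias);
    return activation(weightedSum);
  });
}

// First two layers of the model above: 4 -> 32 and 32 -> 16.
const firstLayerParams = denseParamCount(4, 32);   // 4 * 32 + 32 = 160
const secondLayerParams = denseParamCount(32, 16); // 32 * 16 + 16 = 528

// Tiny hand-picked example: 2 inputs feeding 2 units.
const hidden = denseForward([1, 2], [[1, -1], [0.5, 1]], [0, -3], relu);
```

In the hand-picked example, the first unit computes 0 + 1*1 + 2*0.5 = 2 and the second computes -3 + 1*(-1) + 2*1 = -2, which ReLU clips to 0, so `hidden` is `[2, 0]`.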

Preparing Data

Data preparation is crucial for ML success. Let's prepare some sample data:

// Sample training data (e.g., flower classification)
const trainingData = {
  inputs: tf.tensor2d([
    [5.1, 3.5, 1.4, 0.2], // Iris setosa
    [4.9, 3.0, 1.4, 0.2], // Iris setosa
    [7.0, 3.2, 4.7, 1.4], // Iris versicolor
    [6.4, 3.2, 4.5, 1.5], // Iris versicolor
    [6.3, 3.3, 6.0, 2.5], // Iris virginica
    [5.8, 2.7, 5.1, 1.9]  // Iris virginica
  ]),
  
  outputs: tf.tensor2d([
    [1, 0, 0], // Iris setosa
    [1, 0, 0], // Iris setosa
    [0, 1, 0], // Iris versicolor
    [0, 1, 0], // Iris versicolor
    [0, 0, 1], // Iris virginica
    [0, 0, 1]  // Iris virginica
  ])
};

// Normalize the data (important for neural networks)
const normalizedInputs = trainingData.inputs.div(tf.max(trainingData.inputs));
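
Dividing by the single largest value is the simplest option, but it leaves small features (such as petal width) squeezed near zero. A common alternative is per-feature min-max scaling. The sketch below shows the idea in plain JavaScript; `fitMinMax` and `applyMinMax` are hypothetical helper names, and the statistics are fitted on the training rows only so they can be reused at prediction time.

```javascript
// Compute the per-column minimum and maximum from the training rows.
function fitMinMax(rows) {
  const numFeatures = rows[0].length;
  const min = new Array(numFeatures).fill(Infinity);
  const max = new Array(numFeatures).fill(-Infinity);
  for (const row of rows) {
    row.forEach((value, j) => {
      min[j] = Math.min(min[j], value);
      max[j] = Math.max(max[j], value);
    });
  }
  return { min, max };
}

// Scale each feature using previously fitted statistics.
// Training rows land in [0, 1]; constant columns map to 0.
function applyMinMax(rows, { min, max }) {
  return rows.map((row) =>
    row.map((value, j) => (max[j] === min[j] ? 0 : (value - min[j]) / (max[j] - min[j])))
  );
}

// Same six rows as the sample training data.
const irisRows = [
  [5.1, 3.5, 1.4, 0.2],
  [4.9, 3.0, 1.4, 0.2],
  [7.0, 3.2, 4.7, 1.4],
  [6.4, 3.2, 4.5, 1.5],
  [6.3, 3.3, 6.0, 2.5],
  [5.8, 2.7, 5.1, 1.9]
];

const scaler = fitMinMax(irisRows);
const scaledRows = applyMinMax(irisRows, scaler);
```

The scaled rows can then be passed to `tf.tensor2d`. New samples should go through `applyMinMax` with the same `scaler`, not a freshly fitted one.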

Training the Model

Now let's train our neural network with the prepared data:

// Training configuration
const trainConfig = {
  epochs: 50,
  batchSize: 2,
  validationSplit: 0.2,
  shuffle: true,
  callbacks: {
    onEpochEnd: (epoch, logs) => {
      console.log(`Epoch ${epoch + 1}: loss = ${logs.loss.toFixed(4)}, accuracy = ${logs.acc.toFixed(4)}`);
    }
  }
};

// Train the model
async function trainModel() {
  const history = await model.fit(
    normalizedInputs,
    trainingData.outputs,
    trainConfig
  );
  
  console.log('Training completed!');
  return history;
}

// Start training
trainModel().then(history => {
  console.log('Final accuracy:', history.history.acc[history.history.acc.length - 1]);
});
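
One detail worth knowing: according to the TensorFlow.js documentation, `validationSplit` takes its validation samples from the end of the data, before any shuffling. With rows sorted by class, as in the sample data, the held-out rows would all come from the last class. A minimal sketch of a fix is to shuffle inputs and labels together before building the tensors; `shuffleTogether` is a hypothetical helper.

```javascript
// Fisher-Yates shuffle applied to inputs and labels together,
// so each input row stays paired with its label.
function shuffleTogether(inputs, labels, random = Math.random) {
  const indices = inputs.map((_, i) => i);
  for (let i = indices.length - 1; i > 0; i--) {
    const j = Math.floor(random() * (i + 1));
    [indices[i], indices[j]] = [indices[j], indices[i]];
  }
  return {
    inputs: indices.map((i) => inputs[i]),
    labels: indices.map((i) => labels[i])
  };
}
```

Usage would look like `const shuffled = shuffleTogether(rows, oneHotLabels)`, followed by `tf.tensor2d(shuffled.inputs)` and `tf.tensor2d(shuffled.labels)`.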

Making Predictions

Once the model is trained, we can use it to make predictions:

// New data to predict
const newData = tf.tensor2d([
  [5.0, 3.6, 1.4, 0.2], // Should be Iris setosa
  [6.7, 3.0, 5.2, 2.3]  // Should be Iris virginica
]);

// Normalize the new data using the same parameters
const normalizedNewData = newData.div(tf.max(trainingData.inputs));

// Make predictions
const predictions = model.predict(normalizedNewData);
const predictionData = await predictions.array();

// Interpret results
const classes = ['Iris Setosa', 'Iris Versicolor', 'Iris Virginica'];

predictionData.forEach((prediction, index) => {
  const maxIndex = prediction.indexOf(Math.max(...prediction));
  const confidence = Math.max(...prediction) * 100;
  console.log(`Sample ${index + 1}: ${classes[maxIndex]} (${confidence.toFixed(2)}% confidence)`);
});
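
The "confidence" above is simply the largest value in the model's output vector. For a softmax output layer, those values are produced roughly as in the following plain-JavaScript sketch; `softmax` and `argMax` are hypothetical helpers, and the input scores are made up.

```javascript
// Turn raw scores (logits) into probabilities that sum to 1.
// Subtracting the maximum first avoids overflow in Math.exp.
function softmax(logits) {
  const maxLogit = Math.max(...logits);
  const exps = logits.map((v) => Math.exp(v - maxLogit));
  const sum = exps.reduce((a, b) => a + b, 0);
  return exps.map((v) => v / sum);
}

// Index of the largest value (the first one wins on ties).
function argMax(values) {
  return values.reduce((best, v, i) => (v > values[best] ? i : best), 0);
}

const probabilities = softmax([2, 1, 0]); // made-up raw scores
const predictedClass = argMax(probabilities);
```

A high softmax value means the model prefers that class over the others; it is not a calibrated probability, especially for a model trained on six samples.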

Image Classification in Browser

TensorFlow.js shines in browser-based ML applications. Here's how to classify images:

// Load a pre-trained model (MobileNet)
// Models hosted on TF Hub are graph models, so use loadGraphModel with fromTFHub
async function loadModel() {
  const model = await tf.loadGraphModel(
    'https://tfhub.dev/google/tfjs-model/imagenet/mobilenet_v2_100_224/classification/3/default/1',
    { fromTFHub: true }
  );
  return model;
}

// Preprocess image for model
function preprocessImage(imageElement) {
  // tf.tidy disposes the intermediate tensors created inside the callback
  return tf.tidy(() =>
    tf.browser.fromPixels(imageElement)
      .resizeBilinear([224, 224])
      .toFloat()
      .div(tf.scalar(255)) // TF Hub image models expect pixel values in [0, 1]
      .expandDims()
  );
}

// Classify image
// IMAGENET_CLASSES is assumed to be an array of label names in the model's output order
async function classifyImage(imageElement, model) {
  const preprocessed = preprocessImage(imageElement);
  const output = model.predict(preprocessed);
  const scores = await output.data();

  // Free the tensors once the values have been copied out
  preprocessed.dispose();
  output.dispose();

  // Get top 5 predictions
  const top5 = Array.from(scores)
    .map((probability, index) => ({
      probability: probability,
      className: IMAGENET_CLASSES[index]
    }))
    .sort((a, b) => b.probability - a.probability)
    .slice(0, 5);

  return top5;
}

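// Usage in HTML
let modelPromise = null; // Load the model once and reuse it

document.getElementById('image-input').addEventListener('change', async (event) => {
  const file = event.target.files[0];
  if (!file) return;
  const imageElement = document.getElementById('preview');

  imageElement.src = URL.createObjectURL(file);
  await imageElement.decode(); // Wait until the image has loaded before reading pixels

  modelPromise = modelPromise || loadModel();
  const model = await modelPromise;
  const predictions = await classifyImage(imageElement, model);

  displayPredictions(predictions); // App-specific rendering helper (not shown)
});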
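
Preprocessing has to match what the model was trained on, and different image models expect different pixel ranges. The two most common conventions are shown below in plain JavaScript. The helper names are hypothetical; which convention applies to a given model is something to confirm in that model's documentation.

```javascript
// Map 8-bit channel values (0..255) to the range [0, 1].
const toUnitRange = (pixels) => pixels.map((v) => v / 255);

// Map 8-bit channel values (0..255) to the range [-1, 1].
const toSignedRange = (pixels) => pixels.map((v) => v / 127.5 - 1);

const samplePixels = [0, 255];
const unitScaled = toUnitRange(samplePixels);     // [0, 1]
const signedScaled = toSignedRange(samplePixels); // [-1, 1]
```

Feeding a model the wrong range usually does not throw an error. It just produces poor predictions, which makes this mistake easy to miss.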

Real-time Object Detection

Detect objects in real time from a webcam feed:

// Setup webcam
async function setupWebcam() {
  const video = document.getElementById('webcam');
  const stream = await navigator.mediaDevices.getUserMedia({ video: true });
  video.srcObject = stream;
  
  return new Promise((resolve) => {
    video.onloadedmetadata = () => resolve(video);
  });
}

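// Load COCO-SSD model for object detection
// Requires the @tensorflow-models/coco-ssd package:
//   import * as cocoSsd from '@tensorflow-models/coco-ssd';
async function loadObjectDetectionModel() {
  const model = await cocoSsd.load();
  return model;
}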

// Detect objects in real-time
async function detectObjects() {
  const video = document.getElementById('webcam');
  const model = await loadObjectDetectionModel();
  const canvas = document.getElementById('canvas');
  const ctx = canvas.getContext('2d');
  
  canvas.width = video.videoWidth;
  canvas.height = video.videoHeight;
  
  async function detect() {
    const predictions = await model.detect(video);
    
    // Clear canvas
    ctx.clearRect(0, 0, canvas.width, canvas.height);
    
    // Draw bounding boxes
    predictions.forEach(prediction => {
      const [x, y, width, height] = prediction.bbox;
      
      // Draw box
      ctx.strokeStyle = '#00FF00';
      ctx.lineWidth = 2;
      ctx.strokeRect(x, y, width, height);
      
      // Draw label
      ctx.fillStyle = '#00FF00';
      ctx.fillText(
        `${prediction.class} (${Math.round(prediction.score * 100)}%)`,
        x,
        y > 10 ? y - 5 : 10
      );
    });
    
    requestAnimationFrame(detect);
  }
  
  detect();
}

// Start detection
setupWebcam().then(detectObjects);
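
Raw detections often include low-confidence or overlapping boxes. The sketch below shows two small plain-JavaScript helpers that could be applied to the `predictions` array before drawing. It assumes each prediction has the `bbox` ([x, y, width, height]), `class`, and `score` fields used above; the helper names and the 0.5 threshold are arbitrary choices.

```javascript
// Keep only detections at or above a confidence threshold.
function filterDetections(predictions, minScore = 0.5) {
  return predictions.filter((p) => p.score >= minScore);
}

// Intersection over union for two [x, y, width, height] boxes.
// 1 means identical boxes, 0 means no overlap.
function iou(a, b) {
  const x1 = Math.max(a[0], b[0]);
  const y1 = Math.max(a[1], b[1]);
  const x2 = Math.min(a[0] + a[2], b[0] + b[2]);
  const y2 = Math.min(a[1] + a[3], b[1] + b[3]);
  const intersection = Math.max(0, x2 - x1) * Math.max(0, y2 - y1);
  const union = a[2] * a[3] + b[2] * b[3] - intersection;
  return union === 0 ? 0 : intersection / union;
}

// Made-up detections for illustration.
const sampleDetections = [
  { bbox: [0, 0, 10, 10], class: 'cat', score: 0.9 },
  { bbox: [5, 0, 10, 10], class: 'cat', score: 0.3 },
  { bbox: [20, 20, 5, 5], class: 'cup', score: 0.6 }
];
const confident = filterDetections(sampleDetections);
```

In the drawing loop, `predictions` would be replaced by `filterDetections(predictions)`. The `iou` helper is useful for spotting two boxes that likely cover the same object.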

Transfer Learning

Use pre-trained models and adapt them for your specific tasks:

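// Load the MobileNet feature-vector model (MobileNet without its classification layer).
// Models hosted on TF Hub are graph models, so use loadGraphModel with fromTFHub.
const baseModel = await tf.loadGraphModel(
  'https://tfhub.dev/google/tfjs-model/imagenet/mobilenet_v2_100_224/feature_vector/3/default/1',
  { fromTFHub: true }
);

// A graph model is inference-only, so the base is effectively frozen.
// Use it as a fixed feature extractor.
// Assumed inputs: `images` is a float tensor of shape [n, 224, 224, 3] scaled to [0, 1],
// `labels` is a one-hot tensor of shape [n, numClasses].
const features = tf.tidy(() => baseModel.predict(images));

// Create a new classifier on top of the extracted features
const model = tf.sequential({
  layers: [
    tf.layers.dense({ units: 128, activation: 'relu', inputShape: [features.shape[1]] }),
    tf.layers.dropout({ rate: 0.2 }),
    tf.layers.dense({ units: numClasses, activation: 'softmax' })
  ]
});

// Compile with a low learning rate
model.compile({
  optimizer: tf.train.adam(0.0001),
  loss: 'categoricalCrossentropy',
  metrics: ['accuracy']
});

// Train on your custom dataset
const history = await model.fit(features, labels, {
  epochs: 10,
  batchSize: 32,
  validationSplit: 0.2
});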
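
Because the model is compiled with `categoricalCrossentropy`, the labels of a custom dataset need to be one-hot encoded before training. A minimal plain-JavaScript sketch follows; `buildLabelIndex` and `oneHot` are hypothetical helpers and the label names are made up.

```javascript
// Map each distinct label name to a stable integer index.
function buildLabelIndex(labels) {
  const classes = [...new Set(labels)].sort();
  const index = new Map(classes.map((name, i) => [name, i]));
  return { classes, index };
}

// Convert label names to one-hot rows, e.g. index 1 of 3 -> [0, 1, 0].
function oneHot(labels, index) {
  return labels.map((label) => {
    const row = new Array(index.size).fill(0);
    row[index.get(label)] = 1;
    return row;
  });
}

// Made-up labels for a custom dataset.
const rawLabels = ['cat', 'dog', 'cat', 'bird'];
const labelInfo = buildLabelIndex(rawLabels);
const encodedLabels = oneHot(rawLabels, labelInfo.index);
```

Here `labelInfo.classes` is `['bird', 'cat', 'dog']`, its length can serve as `numClasses`, and `encodedLabels` can be passed to `tf.tensor2d`.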

Best Practices

Data Preparation

  • Always normalize your data before training
  • Split data into training, validation, and test sets
  • Augment data to increase dataset size
  • Handle missing values appropriately
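
As a small illustration of the second point, a three-way split can be sketched in plain JavaScript as follows. The `splitDataset` helper and the 60/20/20 proportions are assumptions, and the items should be shuffled before splitting.

```javascript
// Split an already-shuffled array into training, validation, and test sets.
function splitDataset(items, trainFraction = 0.6, validationFraction = 0.2) {
  const trainEnd = Math.round(items.length * trainFraction);
  const validationEnd = trainEnd + Math.round(items.length * validationFraction);
  return {
    train: items.slice(0, trainEnd),
    validation: items.slice(trainEnd, validationEnd),
    test: items.slice(validationEnd) // whatever remains
  };
}

const sampleItems = Array.from({ length: 10 }, (_, i) => i);
const parts = splitDataset(sampleItems);
```

The test set should only be used once, at the very end, so that it gives an honest estimate of how the model handles unseen data.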

Model Architecture

  • Start with simple models and gradually increase complexity
  • Use appropriate activation functions
  • Implement dropout to prevent overfitting
  • Consider batch normalization for better training
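
To show what dropout does during training, here is a minimal plain-JavaScript sketch of the common "inverted dropout" formulation: each activation is zeroed with probability `rate`, and the survivors are scaled up so the expected sum stays the same. The `dropout` function and its injectable `random` parameter are hypothetical, written this way to make the behavior easy to test.

```javascript
// Inverted dropout: drop each activation with probability `rate`,
// and divide the kept ones by (1 - rate) to preserve the expected value.
function dropout(activations, rate, random = Math.random) {
  const keepProbability = 1 - rate;
  return activations.map((value) =>
    random() < keepProbability ? value / keepProbability : 0
  );
}

// With a fixed "random" source the result is predictable:
const allKept = dropout([1, 2], 0.5, () => 0.1);    // every unit kept and doubled
const allDropped = dropout([1, 2], 0.5, () => 0.9); // every unit zeroed
```

At inference time dropout is turned off, which `tf.layers.dropout` handles automatically when you call `predict`.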

Performance Optimization

  • Use Web Workers for heavy computations
  • Implement model quantization for smaller file sizes
  • Cache preprocessed data when possible
  • Use GPU acceleration when available
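
To give a feel for why quantization shrinks models, the sketch below stores float weights as 8-bit integers plus a minimum and a scale, which is roughly a 4x reduction compared to 32-bit floats. This is a simplified illustration of the idea in plain JavaScript, not the exact scheme used by the TensorFlow.js converter; the function names are hypothetical.

```javascript
// Affine quantization: map floats in [min, max] onto integers 0..255.
function quantize(weights) {
  const min = Math.min(...weights);
  const max = Math.max(...weights);
  const scale = (max - min) / 255 || 1; // fall back to 1 when all weights are equal
  const values = Uint8Array.from(weights, (w) => Math.round((w - min) / scale));
  return { values, min, scale };
}

// Recover approximate float weights from the quantized form.
function dequantize({ values, min, scale }) {
  return Array.from(values, (q) => q * scale + min);
}

const originalWeights = [-1, -0.25, 0.5, 1]; // made-up weights
const quantized = quantize(originalWeights);
const restoredWeights = dequantize(quantized);
```

Each restored weight differs from the original by at most half a quantization step, which is often small enough that accuracy barely changes. That is worth measuring for your own model.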

Conclusion

Machine Learning with JavaScript has opened up incredible possibilities for web developers. From simple classification tasks to complex real-time object detection, TensorFlow.js provides the tools needed to build sophisticated ML applications entirely in JavaScript.

The key to success is understanding the fundamentals, starting with simple projects, and gradually building up complexity. Remember that ML is an iterative process: experiment with different architectures, hyperparameters, and techniques to find what works best for your specific use case.
