hongbo-miao/hongbomiao.com

View on GitHub
high-performance-computing/hm-jax/src/main.py

Summary

Maintainability
A
35 mins
Test Coverage
import logging

import jax.numpy as jnp
import matplotlib.pyplot as plt
from jax import grad, jit, random


# The linear model
def predict(w: float, b: float, X: jnp.ndarray) -> jnp.ndarray:
    return w * X + b


# The loss function (mean squared error)
def loss_fn(w: float, b: float, X: jnp.ndarray, y: jnp.ndarray) -> float:
    predictions = predict(w, b, X)
    return jnp.mean((predictions - y) ** 2)


# JIT compile the update step for efficiency
@jit
def update(
    w: float, b: float, X: jnp.ndarray, y: jnp.ndarray, learning_rate: float
) -> tuple[float, float]:
    dw, db = grad(loss_fn, argnums=(0, 1))(w, b, X, y)
    w -= learning_rate * dw
    b -= learning_rate * db
    return w, b


def main() -> None:
    # Generate synthetic data
    key = random.PRNGKey(0)
    N = 100  # number of data points
    X = random.normal(key, (N,))  # features

    # Adding even more non-linearity and varying noise
    true_w, true_b = 2.0, -1.0  # true weights for the linear model
    y = (
        true_w * X
        + true_b
        + jnp.sin(X) * 0.5
        + jnp.cos(2 * X) * 0.3  # additional non-linear component
        + random.normal(key, (N,)) * jnp.abs(X) * 0.5  # varying noise based on X
        + random.normal(key, (N,)) * 2.0  # additional noise for more randomness
    )

    # Initialize weights
    w = jnp.array(0.0)
    b = jnp.array(0.0)

    # Define the learning rate and the number of iterations
    learning_rate = 0.1
    num_iterations = 100

    # Train
    losses = []
    for i in range(num_iterations):
        w, b = update(w, b, X, y, learning_rate)
        current_loss = loss_fn(w, b, X, y)
        losses.append(current_loss)
        if i % 10 == 0:
            logging.info(
                f"Iteration {i}: loss = {current_loss:.4f}, w = {w:.4f}, b = {b:.4f}"
            )

    # Plot the results
    plt.plot(X, y, "bo", label="Data")
    plt.plot(X, predict(w, b, X), "r-", label="Fitted Line")
    plt.legend()
    plt.title(f"Linear Regression: w = {w:.2f}, b = {b:.2f}")
    plt.show()

    plt.plot(losses)
    plt.title("Loss over iterations")
    plt.xlabel("Iteration")
    plt.ylabel("Loss")
    plt.show()


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    main()