Difference between revisions of "MinPPO"

From Humanoid Robots Wiki
Jump to: navigation, search
(Created page with "These are notes for the MinPPO project [https://github.com/kscalelabs/minppo here]. == Testing == * Hidden layer size of 256 shows progress (loss is based on state.q[2]) * s...")
 
m
 
Line 9: Line 9:
 
* first thing to become nans seems to be actor loss and scores. after that, everything becomes nans
 
* first thing to become nans seems to be actor loss and scores. after that, everything becomes nans
 
* fixed entropy epsilon. hope this works now.
 
* fixed entropy epsilon. hope this works now.
 +
 +
== MNIST Training Example ==
 +
 +
<syntaxhighlight lang="python">
 +
"""Trains a simple MNIST model using Equinox."""
 +
 +
import logging
 +
 +
import equinox as eqx
 +
import jax
 +
import jax.numpy as jnp
 +
import optax
 +
from jax import random
 +
from tensorflow.keras.datasets import mnist
 +
 +
logging.basicConfig(
 +
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
 +
)
 +
logger = logging.getLogger(__name__)
 +
 +
 +
def load_mnist() -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
 +
    """Load and preprocess the MNIST dataset."""
 +
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
 +
    x_train = jnp.float32(x_train) / 255.0
 +
    x_test = jnp.float32(x_test) / 255.0
 +
    x_train = x_train.reshape(-1, 28 * 28)
 +
    x_test = x_test.reshape(-1, 28 * 28)
 +
    y_train = jax.nn.one_hot(y_train, 10)
 +
    y_test = jax.nn.one_hot(y_test, 10)
 +
    return x_train, y_train, x_test, y_test
 +
 +
 +
class DenseModel(eqx.Module):
 +
    """Define a simple dense neural network model."""
 +
 +
    layers: list
 +
 +
    def __init__(self, key: jnp.ndarray) -> None:
 +
        keys = random.split(key, 3)
 +
        self.layers = [
 +
            eqx.nn.Linear(28 * 28, 128, key=keys[0]),
 +
            eqx.nn.Linear(128, 64, key=keys[1]),
 +
            eqx.nn.Linear(64, 10, key=keys[2]),
 +
        ]
 +
 +
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
 +
        for layer in self.layers[:-1]:
 +
            x = jax.nn.relu(layer(x))
 +
        return self.layers[-1](x)
 +
 +
 +
@eqx.filter_value_and_grad
 +
def loss_fn(model: DenseModel, x_b: jnp.ndarray, y_b: jnp.ndarray) -> jnp.ndarray:
 +
    """Define the loss function (cross-entropy loss). Vecotrized across batch with vmap."""
 +
    pred_b = jax.vmap(model)(x_b)
 +
    return -jnp.mean(jnp.sum(y_b * jax.nn.log_softmax(pred_b), axis=-1))
 +
 +
 +
@eqx.filter_jit
 +
def make_step(
 +
    model: DenseModel,
 +
    optimizer: optax.GradientTransformation,
 +
    opt_state: optax.OptState,
 +
    x_b: jnp.ndarray,
 +
    y_b: jnp.ndarray,
 +
) -> tuple[jnp.ndarray, DenseModel, optax.OptState]:
 +
    """Perform a single optimization step."""
 +
    loss, grads = loss_fn(model, x_b, y_b)
 +
    updates, opt_state = optimizer.update(grads, opt_state)
 +
    model = eqx.apply_updates(model, updates)
 +
    return loss, model, opt_state
 +
 +
 +
def train(
 +
    model: DenseModel,
 +
    optimizer: optax.GradientTransformation,
 +
    x_train: jnp.ndarray,
 +
    y_train: jnp.ndarray,
 +
    batch_size: int,
 +
    num_epochs: int,
 +
) -> DenseModel:
 +
    """Train the model using the given data."""
 +
    opt_state = optimizer.init(eqx.filter(model, eqx.is_array))
 +
 +
    for epoch in range(num_epochs):
 +
        for i in range(0, len(x_train), batch_size):
 +
            x_batch = x_train[i : i + batch_size]
 +
            y_batch = y_train[i : i + batch_size]
 +
            loss, model, opt_state = make_step(model, optimizer, opt_state, x_batch, y_batch)
 +
 +
        logger.info(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss:.4f}")
 +
 +
    return model
 +
 +
 +
@eqx.filter_jit
 +
def accuracy(model: DenseModel, x_b: jnp.ndarray, y_b: jnp.ndarray) -> jnp.ndarray:
 +
    """Takes in batch oftest images/label pairing with model and returns accuracy."""
 +
    pred = jax.vmap(model)(x_b)
 +
    return jnp.mean(jnp.argmax(pred, axis=-1) == jnp.argmax(y_b, axis=-1))
 +
 +
 +
def main() -> None:
 +
    # Load data
 +
    x_train, y_train, x_test, y_test = load_mnist()
 +
 +
    # Initialize model and optimizer
 +
    key = random.PRNGKey(0)
 +
    model = DenseModel(key)
 +
    optimizer = optax.adam(learning_rate=1e-3)
 +
 +
    # Train the model
 +
    batch_size = 32
 +
    num_epochs = 10
 +
    trained_model = train(model, optimizer, x_train, y_train, batch_size, num_epochs)
 +
 +
    test_accuracy = accuracy(trained_model, x_test, y_test)
 +
    logger.info(f"Test accuracy: {test_accuracy:.4f}")
 +
 +
 +
if __name__ == "__main__":
 +
    main()
 +
</syntaxhighlight>

Latest revision as of 22:21, 20 August 2024

These are notes for the MinPPO project here.

Testing[edit]

  • Hidden layer size of 256 shows progress (loss is based on state.q[2])
  • setting std to zero makes rewards nans why. I wonder if there NEEDS to be randomization in the enviornment
  • ctrl cost is whats giving nans? interesting?
  • it is unrelated to randomization of enviornmnet. i think gradient related
  • first thing to become nans seems to be actor loss and scores. after that, everything becomes nans
  • fixed entropy epsilon. hope this works now.

MNIST Training Example[edit]

"""Trains a simple MNIST model using Equinox."""

import logging

import equinox as eqx
import jax
import jax.numpy as jnp
import optax
from jax import random
from tensorflow.keras.datasets import mnist

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
)
logger = logging.getLogger(__name__)


def load_mnist() -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Load and preprocess the MNIST dataset."""
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    x_train = jnp.float32(x_train) / 255.0
    x_test = jnp.float32(x_test) / 255.0
    x_train = x_train.reshape(-1, 28 * 28)
    x_test = x_test.reshape(-1, 28 * 28)
    y_train = jax.nn.one_hot(y_train, 10)
    y_test = jax.nn.one_hot(y_test, 10)
    return x_train, y_train, x_test, y_test


class DenseModel(eqx.Module):
    """Define a simple dense neural network model."""

    layers: list

    def __init__(self, key: jnp.ndarray) -> None:
        keys = random.split(key, 3)
        self.layers = [
            eqx.nn.Linear(28 * 28, 128, key=keys[0]),
            eqx.nn.Linear(128, 64, key=keys[1]),
            eqx.nn.Linear(64, 10, key=keys[2]),
        ]

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        for layer in self.layers[:-1]:
            x = jax.nn.relu(layer(x))
        return self.layers[-1](x)


@eqx.filter_value_and_grad
def loss_fn(model: DenseModel, x_b: jnp.ndarray, y_b: jnp.ndarray) -> jnp.ndarray:
    """Define the loss function (cross-entropy loss). Vecotrized across batch with vmap."""
    pred_b = jax.vmap(model)(x_b)
    return -jnp.mean(jnp.sum(y_b * jax.nn.log_softmax(pred_b), axis=-1))


@eqx.filter_jit
def make_step(
    model: DenseModel,
    optimizer: optax.GradientTransformation,
    opt_state: optax.OptState,
    x_b: jnp.ndarray,
    y_b: jnp.ndarray,
) -> tuple[jnp.ndarray, DenseModel, optax.OptState]:
    """Perform a single optimization step."""
    loss, grads = loss_fn(model, x_b, y_b)
    updates, opt_state = optimizer.update(grads, opt_state)
    model = eqx.apply_updates(model, updates)
    return loss, model, opt_state


def train(
    model: DenseModel,
    optimizer: optax.GradientTransformation,
    x_train: jnp.ndarray,
    y_train: jnp.ndarray,
    batch_size: int,
    num_epochs: int,
) -> DenseModel:
    """Train the model using the given data."""
    opt_state = optimizer.init(eqx.filter(model, eqx.is_array))

    for epoch in range(num_epochs):
        for i in range(0, len(x_train), batch_size):
            x_batch = x_train[i : i + batch_size]
            y_batch = y_train[i : i + batch_size]
            loss, model, opt_state = make_step(model, optimizer, opt_state, x_batch, y_batch)

        logger.info(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss:.4f}")

    return model


@eqx.filter_jit
def accuracy(model: DenseModel, x_b: jnp.ndarray, y_b: jnp.ndarray) -> jnp.ndarray:
    """Takes in batch oftest images/label pairing with model and returns accuracy."""
    pred = jax.vmap(model)(x_b)
    return jnp.mean(jnp.argmax(pred, axis=-1) == jnp.argmax(y_b, axis=-1))


def main() -> None:
    # Load data
    x_train, y_train, x_test, y_test = load_mnist()

    # Initialize model and optimizer
    key = random.PRNGKey(0)
    model = DenseModel(key)
    optimizer = optax.adam(learning_rate=1e-3)

    # Train the model
    batch_size = 32
    num_epochs = 10
    trained_model = train(model, optimizer, x_train, y_train, batch_size, num_epochs)

    test_accuracy = accuracy(trained_model, x_test, y_test)
    logger.info(f"Test accuracy: {test_accuracy:.4f}")


if __name__ == "__main__":
    main()