Attention-DQN RPS Agent

An attention-based Deep Q-Network agent for playing Repeated Rock-Paper-Scissors (RRPS) using PyTorch and OpenSpiel.

Overview

This repo trains a deep Q-network agent to play Repeated Rock-Paper-Scissors against a pool of bots. The agent uses an attention-based neural network to process long action histories and learns to adapt to varied opponent strategies using prioritized experience replay.

Technical Details

Model Architecture

The agent uses a custom PyTorch model defined in model.py:

  • RPSDualHeadAttention:
    • Input: Sequence of one-hot encoded action pairs (agent and opponent) over a configurable context window.
    • Layers:
      • Linear projection from input dimension (2 * num_actions) to a hidden dimension.
      • Multi-head self-attention (nn.MultiheadAttention) to capture temporal dependencies and patterns in the action history.
      • Output head: Linear layer producing Q-values for each possible action.
    • Output: Q-values for the agent's possible actions at the current timestep.
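The architecture described above can be sketched as follows. This is an illustrative reconstruction from the layer list, not the repo's actual model.py; the layer names, hidden size, and head count are assumptions.

```python
import torch
import torch.nn as nn

class RPSDualHeadAttention(nn.Module):
    """Sketch of the attention Q-network: linear projection,
    self-attention over the action history, Q-value head."""
    def __init__(self, num_actions=3, hidden_dim=64, num_heads=4):
        super().__init__()
        self.proj = nn.Linear(2 * num_actions, hidden_dim)
        self.attn = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)
        self.q_head = nn.Linear(hidden_dim, num_actions)

    def forward(self, x):
        # x: (batch, context_len, 2 * num_actions) one-hot agent/opponent pairs
        h = self.proj(x)
        h, _ = self.attn(h, h, h)      # self-attention over the history
        return self.q_head(h[:, -1])   # Q-values read off the last timestep

model = RPSDualHeadAttention()
q = model(torch.zeros(1, 20, 6))  # context of 20 action pairs
print(q.shape)  # torch.Size([1, 3])
```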

Experience Replay

  • PrioritizedReplayBuffer:
    • Stores transitions as (state, action, opponent_action, reward, next_state, done).
    • Assigns priorities based on the temporal-difference (TD) error.
    • Samples batches with probability proportional to priority, focusing learning on surprising or informative experiences.
    • Supports importance sampling weights to correct for the bias introduced by prioritized sampling.
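A minimal proportional-prioritization buffer along these lines is sketched below. The list-based storage is for clarity only (a production buffer would use a sum tree for O(log n) sampling); the method names and defaults are assumptions, not the repo's actual code.

```python
import numpy as np

class PrioritizedReplayBuffer:
    """Sketch of proportional PER: sample with probability ~ priority^alpha,
    correct the induced bias with importance-sampling weights."""
    def __init__(self, capacity, alpha=0.6, eps=1e-5):
        self.capacity, self.alpha, self.eps = capacity, alpha, eps
        self.data, self.priorities = [], []

    def push(self, transition):
        if len(self.data) >= self.capacity:
            self.data.pop(0); self.priorities.pop(0)
        self.data.append(transition)
        # new transitions get the current max priority so each is seen at least once
        self.priorities.append(max(self.priorities, default=1.0))

    def sample(self, batch_size, beta=0.4):
        probs = np.array(self.priorities) ** self.alpha
        probs /= probs.sum()
        idx = np.random.choice(len(self.data), batch_size, p=probs)
        # importance-sampling weights, normalized by the max for stability
        weights = (len(self.data) * probs[idx]) ** -beta
        weights /= weights.max()
        return [self.data[i] for i in idx], idx, weights

    def update_priorities(self, idx, td_errors):
        # priority = |TD error| + eps, so no transition starves at zero
        for i, err in zip(idx, td_errors):
            self.priorities[i] = abs(err) + self.eps
```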

Training Loop

  • Environment: Uses OpenSpiel's repeated RPS environment.
  • Agent: TrainingAgent wraps the DQN logic, model, optimizer, and replay buffer.
  • Opponent Pool: Randomly samples from a set of bots each episode.
  • Step:
    1. Agent and opponent select actions.
    2. Environment returns next state and reward.
    3. Transition is stored in the replay buffer.
    4. Periodically, a batch is sampled and the model is updated using DQN loss.
    5. Target network is updated at fixed intervals.
  • Checkpointing: Model weights are saved at regular intervals and at the end of training.
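Step 4 above (the DQN update on a sampled batch) might look like the sketch below: compute the TD error against a frozen target network, weight the loss by the PER importance-sampling weights, and return the absolute TD errors so the buffer's priorities can be refreshed. Function and variable names here are illustrative assumptions.

```python
import torch
import torch.nn.functional as F

def dqn_update(model, target, optimizer, batch, weights, gamma=0.99):
    """One IS-weighted DQN update; returns |TD error| per sample
    for refreshing replay priorities."""
    states, actions, rewards, next_states, dones = batch
    q = model(states).gather(1, actions.unsqueeze(1)).squeeze(1)
    with torch.no_grad():
        q_next = target(next_states).max(1).values
        target_q = rewards + gamma * q_next * (1 - dones)
    td_errors = q - target_q
    loss = (weights * td_errors.pow(2)).mean()  # importance-weighted MSE
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return td_errors.detach().abs()
```

The target network referenced in step 5 would be a periodically synced copy of `model` (e.g. `target.load_state_dict(model.state_dict())` every N updates).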

Installation

  1. Clone the repository:

    git clone https://github.com/UltimateBoomer/attn-dqn-rps-agent.git
    cd attn-dqn-rps-agent
  2. Install dependencies:

    • Python 3.8+
    • PyTorch
    • OpenSpiel (Python bindings)
    • numpy

    You can install dependencies via pip:

    pip install torch numpy open_spiel

Usage

Train the agent with default settings:

python train.py

Command-Line Arguments

  • --episodes: Number of training episodes (default: 100000)
  • --recall: Length of action history (default: 20)
  • --save_path: Path to save model checkpoints
  • --checkpoint_every: Save checkpoint every N episodes
  • --log_every: Print logs every N episodes
  • --temperature: Action selection temperature
  • --opp_loss_weight: Weight for opponent modeling loss (if applicable)
  • --seed: Random seed
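One plausible reading of --temperature is Boltzmann (softmax) exploration over the Q-values, sketched below; this is an assumption about how the flag is used, not the repo's confirmed action-selection rule.

```python
import torch

def select_action(q_values, temperature=1.0):
    """Softmax exploration sketch: higher temperature flattens the
    distribution; near-zero temperature approaches greedy argmax."""
    probs = torch.softmax(q_values / temperature, dim=-1)
    return torch.multinomial(probs, 1).item()
```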

Example:

python train.py --episodes 50000 --recall 50 --save_path my_agent.pth

Configuration

All hyperparameters can be adjusted in train.py via the TrainConfig dataclass or command-line arguments. Notable parameters:

  • hidden_dim: Hidden size for the attention model.
  • context_len: Number of past action pairs to consider.
  • buffer_size: Replay buffer capacity.
  • batch_size: Training batch size.
  • gamma: Discount factor.
  • lr: Learning rate.
  • per_alpha, per_beta_start, per_beta_frames, per_eps: PER parameters.
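Gathering the parameters listed above, the TrainConfig dataclass might look like the fragment below. The default values are illustrative assumptions, not the repo's actual settings.

```python
from dataclasses import dataclass

@dataclass
class TrainConfig:
    # defaults are placeholders; override via CLI arguments or in train.py
    hidden_dim: int = 64        # hidden size for the attention model
    context_len: int = 20       # past action pairs fed to the network
    buffer_size: int = 100_000  # replay buffer capacity
    batch_size: int = 64
    gamma: float = 0.99         # discount factor
    lr: float = 1e-3            # learning rate
    per_alpha: float = 0.6      # prioritization exponent
    per_beta_start: float = 0.4 # initial IS-weight exponent
    per_beta_frames: int = 100_000  # frames to anneal beta to 1.0
    per_eps: float = 1e-5       # floor added to priorities
```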
