An attention-based Deep Q-Network agent for playing Repeated Rock-Paper-Scissors (RRPS) using PyTorch and OpenSpiel.
This repo trains a deep neural network agent to play Repeated Rock-Paper-Scissors against a pool of bots. The agent uses an attention-based network to process long action histories and learns to adapt to varied opponent strategies with prioritized experience replay.
The agent uses a custom PyTorch model defined in model.py:
- RPSDualHeadAttention (a minimal sketch follows this list):
  - Input: Sequence of one-hot encoded action pairs (agent and opponent) over a configurable context window.
  - Layers:
    - Linear projection from the input dimension (`2 * num_actions`) to a hidden dimension.
    - Multi-head self-attention (`nn.MultiheadAttention`) to capture temporal dependencies and patterns in the action history.
    - Output head: Linear layer producing Q-values for each possible action.
  - Output: Q-values for the agent's possible actions at the current timestep.
- PrioritizedReplayBuffer:
  - Stores transitions as `(state, action, opponent_action, reward, next_state, done)`.
  - Assigns priorities based on the temporal-difference (TD) error.
  - Samples batches with probability proportional to priority, focusing learning on surprising or informative experiences.
  - Supports importance-sampling weights to correct for the bias introduced by prioritized sampling.
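For orientation, here is a minimal sketch of how the attention model described above could be laid out in PyTorch. The dimensions, head count, and the opponent-prediction head (suggested by the "dual head" name and the `--opp_loss_weight` flag) are assumptions, not necessarily what `model.py` implements:

```python
import torch
import torch.nn as nn


class RPSDualHeadAttention(nn.Module):
    """Sketch of the attention Q-network: a shared attention trunk with a
    Q-value head and (assumed) an opponent-prediction head."""

    def __init__(self, num_actions: int = 3, hidden_dim: int = 64, num_heads: int = 4):
        super().__init__()
        # Each timestep concatenates two one-hot vectors: agent action + opponent action.
        self.input_proj = nn.Linear(2 * num_actions, hidden_dim)
        self.attention = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)
        self.q_head = nn.Linear(hidden_dim, num_actions)    # Q-values per action
        self.opp_head = nn.Linear(hidden_dim, num_actions)  # opponent-action logits (assumed)

    def forward(self, history: torch.Tensor):
        # history: (batch, context_len, 2 * num_actions)
        x = self.input_proj(history)
        x, _ = self.attention(x, x, x)   # self-attention over the action history
        last = x[:, -1, :]               # representation of the most recent timestep
        return self.q_head(last), self.opp_head(last)


model = RPSDualHeadAttention()
q_values, opp_logits = model(torch.zeros(1, 20, 6))  # one 20-step history, 3 actions
```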
- Environment: Uses OpenSpiel's repeated RPS environment.
- Agent: `TrainingAgent` wraps the DQN logic, model, optimizer, and replay buffer.
- Opponent Pool: Randomly samples from a set of bots each episode.
- Step:
  - Agent and opponent select actions.
  - The environment returns the next state and reward.
  - The transition is stored in the replay buffer.
  - Periodically, a batch is sampled and the model is updated with the DQN loss (see the sketch after this list).
  - The target network is updated at fixed intervals.
- Checkpointing: Model weights are saved at regular intervals and at the end of training.
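A condensed sketch of what one such update might look like, assuming the model returns a (Q-values, opponent-logits) pair as in the sketch above and the buffer exposes `sample` / `update_priorities` methods with roughly these signatures (the actual `train.py` may structure this differently):

```python
import torch


def dqn_update(model, target_model, buffer, optimizer,
               batch_size=64, gamma=0.99, beta=0.4):
    # Sample with priority-proportional probabilities; importance-sampling
    # weights correct the bias introduced by prioritized sampling.
    states, actions, rewards, next_states, dones, weights, indices = \
        buffer.sample(batch_size, beta)

    # Q-values of the actions actually taken.
    q_values = model(states)[0].gather(1, actions.unsqueeze(1)).squeeze(1)

    # Bootstrapped targets from the target network.
    with torch.no_grad():
        next_q = target_model(next_states)[0].max(dim=1).values
        targets = rewards + gamma * next_q * (1.0 - dones)

    td_errors = targets - q_values
    loss = (weights * td_errors.pow(2)).mean()  # IS-weighted squared TD error

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # New priorities follow the magnitude of the TD error.
    buffer.update_priorities(indices, td_errors.abs().detach().cpu().numpy())
    return loss.item()
```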
- Clone the repository:

  ```bash
  git clone https://github.com/yourusername/dqn-rps-agent.git
  cd dqn-rps-agent
  ```

- Install dependencies:
  - Python 3.8+
  - PyTorch
  - OpenSpiel (Python bindings)
  - numpy
You can install dependencies via pip:
```bash
pip install torch numpy open_spiel
```
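To sanity-check the OpenSpiel installation, the repeated RPS game can be loaded via OpenSpiel's `repeated_game` wrapper around the `matrix_rps` stage game (a small sketch; the number of repetitions here is illustrative, not the value used in training):

```python
import pyspiel

# Repeated rock-paper-scissors: matrix_rps wrapped in the repeated_game transform.
game = pyspiel.load_game("repeated_game(stage_game=matrix_rps(),num_repetitions=1000)")
state = game.new_initial_state()
state.apply_actions([0, 1])  # both players move simultaneously each round
print(state.rewards())       # per-player reward for the round just played
```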
Train the agent with default settings:
```bash
python train.py
```

Command-line arguments:

- `--episodes`: Number of training episodes (default: 100000)
- `--recall`: Length of the action history (default: 20)
- `--save_path`: Path to save model checkpoints
- `--checkpoint_every`: Save a checkpoint every N episodes
- `--log_every`: Print logs every N episodes
- `--temperature`: Action selection temperature
- `--opp_loss_weight`: Weight for the opponent modeling loss (if applicable)
- `--seed`: Random seed
Example:
```bash
python train.py --episodes 50000 --recall 50 --save_path my_agent.pth
```

All hyperparameters can be adjusted in `train.py` via the `TrainConfig` dataclass or command-line arguments. Notable parameters:
- `hidden_dim`: Hidden size for the attention model.
- `context_len`: Number of past action pairs to consider.
- `buffer_size`: Replay buffer capacity.
- `batch_size`: Training batch size.
- `gamma`: Discount factor.
- `lr`: Learning rate.
- `per_alpha`, `per_beta_start`, `per_beta_frames`, `per_eps`: Prioritized experience replay (PER) parameters.
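For reference, these parameters map onto a dataclass along these lines. The field names follow the list above; the defaults other than `episodes` and `recall` are placeholders, not necessarily the repo's actual values:

```python
from dataclasses import dataclass


@dataclass
class TrainConfig:
    episodes: int = 100_000         # number of training episodes
    recall: int = 20                # length of the action history fed to the model
    hidden_dim: int = 64            # hidden size for the attention model (placeholder)
    context_len: int = 20           # past action pairs per state (placeholder)
    buffer_size: int = 100_000      # replay buffer capacity (placeholder)
    batch_size: int = 64            # training batch size (placeholder)
    gamma: float = 0.99             # discount factor (placeholder)
    lr: float = 1e-4                # learning rate (placeholder)
    per_alpha: float = 0.6          # priority exponent (placeholder)
    per_beta_start: float = 0.4     # initial importance-sampling exponent (placeholder)
    per_beta_frames: int = 100_000  # steps over which beta anneals to 1.0 (placeholder)
    per_eps: float = 1e-6           # small constant added to priorities (placeholder)
```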