"""Agent for playing Connect 4.
This module defines the Connect 4 agent, which uses language models to generate moves
and analyze positions in the game.
"""
# Standard library imports
import copy
import logging
import time
from typing import Any
from haive.core.engine.agent.agent import register_agent
from langgraph.types import Command
from haive.games.connect4.config import Connect4AgentConfig
from haive.games.connect4.models import Connect4Move, Connect4PlayerDecision
from haive.games.connect4.state import Connect4State
from haive.games.connect4.state_manager import Connect4StateManager
from haive.games.framework.base.agent import GameAgent
# Third-party imports
# Local imports
logger = logging.getLogger(__name__)
[docs]
@register_agent(Connect4AgentConfig)
class Connect4Agent(GameAgent[Connect4AgentConfig]):
"""Agent for playing Connect 4.
This class implements the Connect 4 agent, which uses language models to generate
moves and analyze positions in the game.
"""
def __init__(self, config: Connect4AgentConfig):
""" Init .
Args:
config: [TODO: Add description]
"""
super().__init__(config)
# logger.debug("Engines initialized", extra={"engines": list(self.engines.keys())})
self.state_manager = Connect4StateManager
[docs]
def prepare_move_context(self, state: Connect4State, player: str) -> dict[str, Any]:
"""Prepare context for move generation.
This method prepares the context for move generation by formatting the legal
moves and getting the player's last analysis.
"""
legal_moves = self.state_manager.get_legal_moves(state)
formatted_legal_moves = [f"Column {move.column}" for move in legal_moves]
# Get player's last analysis (if available)
player_analysis = None
if player == "red" and state.red_analysis:
player_analysis = state.red_analysis[-10:]
elif player == "yellow" and state.yellow_analysis:
player_analysis = state.yellow_analysis[-10:]
# Calculate threats
threats = self._calculate_threats(state, player)
# ✅ Extract winning/blocking threats explicitly
return {
"board": state.board_string,
"turn": state.turn,
"color": player,
"legal_moves": formatted_legal_moves,
"move_history": [
f"{i + 1}. {state.turn}: Column {move.column}"
for i, move in enumerate(state.move_history[-5:])
],
"player_analysis": player_analysis,
# Explicitly named
"threats_winning_moves": threats["winning_moves"],
# Explicitly named
"threats_blocking_moves": threats["blocking_moves"],
}
def _calculate_threats(
self, state: Connect4State, player: str
) -> dict[str, list[int]]:
"""Calculate immediate threats and opportunities.
This method calculates the immediate threats and opportunities for the given
player in the current game state.
"""
opponent = "yellow" if player == "red" else "red"
player_winning_moves = []
opponent_winning_moves = []
for col in range(7):
if state.is_column_full(col):
continue
row = state.get_next_row(col)
# ✅ Check if this move wins for the player
temp_state = copy.deepcopy(state)
temp_state.board[row][col] = player
if self.state_manager._check_win(temp_state, row, col):
player_winning_moves.append(col)
# ✅ Check if this move blocks opponent's win
temp_state = copy.deepcopy(state)
temp_state.board[row][col] = opponent
if self.state_manager._check_win(temp_state, row, col):
opponent_winning_moves.append(col)
return {
"winning_moves": player_winning_moves,
"blocking_moves": opponent_winning_moves,
}
[docs]
def prepare_analysis_context(
self, state: Connect4State, player: str
) -> dict[str, Any]:
"""Prepare context for position analysis with correct variables.
This method prepares the context for position analysis by calculating threats
and formatting the required fields.
"""
threats = self._calculate_threats(state, player)
# ✅ Ensure all required fields exist
columns_usage = [
sum(1 for row in state.board if row[col] is not None) for col in range(7)
]
threats_winning_moves = threats.get("winning_moves", [])
threats_blocking_moves = threats.get("blocking_moves", [])
return {
"board": state.board_string,
"turn": state.turn,
"color": player,
"move_history": [
f"{i + 1}. {state.turn}: Column {move.column}"
for i, move in enumerate(state.move_history[-5:])
],
"threats_winning_moves": threats_winning_moves,
"threats_blocking_moves": threats_blocking_moves,
"columns_usage": columns_usage, # ✅ Now the model gets what it expects!
}
[docs]
def make_player1_move(self, state: Connect4State) -> Command:
"""Make a move for the red player.
This method makes a move for the red player in the current game state.
"""
return self.make_move(state, "red")
[docs]
def make_player2_move(self, state: Connect4State) -> Command:
"""Make a move for the yellow player.
This method makes a move for the yellow player in the current game state.
"""
return self.make_move(state, "yellow")
[docs]
def analyze_player1(self, state: Connect4State) -> Command:
"""Analyze position for the red player.
This method analyzes the position for the red player in the current game state.
"""
return self.analyze_position(state, "red")
[docs]
def analyze_player2(self, state: Connect4State) -> Command:
"""Analyze position for the yellow player.
This method analyzes the position for the yellow player in the current game
state.
"""
return self.analyze_position(state, "yellow")
[docs]
def visualize_state(self, state: dict[str, Any]) -> None:
"""Visualize the current game state with better formatting and insights.
This method visualizes the current game state with better formatting and
insights. It displays the current player, game status, board, last move, and
analysis from the previous turn.
"""
# Create a Connect4State from the dict
connect4_state = Connect4State(**state)
logger.info(
"Game state visualization",
extra={
"separator": "=" * 60,
"current_player": connect4_state.turn.upper(),
"game_status": connect4_state.game_status.upper(),
"board": connect4_state.board_string,
},
)
# 📝 Show last move if available
if connect4_state.move_history:
last_move = connect4_state.move_history[-1]
player = "red" if len(connect4_state.move_history) % 2 == 1 else "yellow"
logger.info(
"Last move information",
extra={"player": player.upper(), "column": last_move.column},
)
# 🔍 **Show analysis from the previous turn**
analysis = None
analysis_color = None
if connect4_state.red_analysis and connect4_state.turn == "yellow":
analysis = connect4_state.red_analysis[-1]
analysis_color = "Red"
elif connect4_state.yellow_analysis and connect4_state.turn == "red":
analysis = connect4_state.yellow_analysis[-1]
analysis_color = "Yellow"
if analysis:
threats = analysis.get("threats", {})
logger.info(
"Player analysis",
extra={
"analysis_color": analysis_color,
"position_score": analysis.get("position_score", "N/A"),
"center_control": f"{analysis.get('center_control', 'N/A')}/10",
"suggested_columns": ", ".join(
map(str, analysis.get("suggested_columns", []))
),
"winning_chances": f"{analysis.get('winning_chances', 'N/A')}%",
"threats": {
"winning_moves": ", ".join(
map(str, threats.get("winning_moves", []))
)
or "None",
"blocking_moves": ", ".join(
map(str, threats.get("blocking_moves", []))
)
or "None",
},
},
)
# Short delay for readability
time.sleep(0.5)