import pandas as pd
import numpy as np
import random
import math
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle

# User Inputs
# These module-level settings are read by the script entry point at the
# bottom of the file and forwarded to load_points() / genetic_algorithm().

# CSV file passed to load_points(); it must provide "x" and "y" columns.
csv_file      = "points.csv"
# Number of candidate points per generation (passed as population_size).
n_population  = 150
# Number of generations to run (passed as generations).
n_generations = 500
# Per-coordinate probability that a child is perturbed (passed as mutation_rate).
rate_mutation = 0.10
# Number of best individuals copied unchanged into the next generation
# (passed as elite_size).
n_elite       = 10
# ----- ----- ----- ----- ----- ----- ---- ----- ----- ----- ----- -----
# Load points from CSV
# ----- ----- ----- ----- ----- ----- ---- ----- ----- ----- ----- -----
def load_points(csv_file):
  """Read the point coordinates stored in a CSV file.

  Args:
    csv_file: Location of the CSV file, handed straight to
      ``pandas.read_csv``. The file must contain ``x`` and ``y`` columns;
      any other columns are ignored.

  Returns:
    A NumPy array with one row per CSV record, the ``x`` value in column 0
    and the ``y`` value in column 1.
  """
  frame = pd.read_csv(csv_file)
  coordinates = frame[['x', 'y']]
  return coordinates.to_numpy()


# ----- ----- ----- ----- ----- ----- ---- ----- ----- ----- ----- -----
# Fitness Function
# Lower cumulative distance is better
# ----- ----- ----- ----- ----- ----- ---- ----- ----- ----- ----- -----
def total_distance(candidate, points):
  """Return the summed Euclidean distance from a candidate to every point.

  This is the fitness value of the genetic algorithm: the lower the sum,
  the better the candidate.

  Args:
    candidate: Two-element sequence ``(x, y)``.
    points: Array-like of shape ``(n, 2)`` holding one ``(x, y)`` pair per
      row.

  Returns:
    The sum of the distances as a Python ``float``; ``0.0`` when ``points``
    is empty.

  Raises:
    ValueError: If ``points`` is non-empty but not of shape ``(n, 2)``.
  """
  x, y = candidate

  coords = np.asarray(points, dtype=float)
  if coords.size == 0:
    return 0.0
  if coords.ndim != 2 or coords.shape[1] != 2:
    raise ValueError("points must have shape (n, 2)")

  # One vectorised pass instead of a Python-level loop over the rows: this
  # function is evaluated for every individual of every generation.
  distances = np.hypot(coords[:, 0] - x, coords[:, 1] - y)
  return float(distances.sum())


# ----- ----- ----- ----- ----- ----- ---- ----- ----- ----- ----- -----
# Create Initial Population
# ----- ----- ----- ----- ----- ----- ---- ----- ----- ----- ----- -----
def create_population(pop_size, xmin, xmax, ymin, ymax):
  """Build the initial population of random candidate points.

  Args:
    pop_size: Number of individuals to generate.
    xmin, xmax: Range the x coordinate is drawn from (uniformly).
    ymin, ymax: Range the y coordinate is drawn from (uniformly).

  Returns:
    A list of ``pop_size`` individuals, each a two-element list ``[x, y]``.
  """
  # The x coordinate is drawn before the y coordinate for every individual.
  return [
    [random.uniform(xmin, xmax), random.uniform(ymin, ymax)]
    for _ in range(pop_size)
  ]


# ----- ----- ----- ----- ----- ----- ---- ----- ----- ----- ----- -----
# Tournament Selection
# ----- ----- ----- ----- ----- ----- ---- ----- ----- ----- ----- -----
def selection(population, fitnesses, tournament_size=3):
  """Pick one parent by tournament selection.

  A number of distinct individuals are drawn at random and the one with
  the lowest fitness (smallest cumulative distance) wins the tournament.

  Args:
    population: Sequence of individuals.
    fitnesses: Fitness values, ``fitnesses[i]`` belonging to
      ``population[i]``.
    tournament_size: Number of contenders per tournament. It is clamped to
      the population size, so small populations remain usable.

  Returns:
    The winning individual. This is the object stored in ``population``,
    not a copy.

  Raises:
    ValueError: If there is nobody to select from or ``tournament_size``
      is smaller than 1.
  """
  # Only pair up as many entries as both sequences provide.
  pool_size = min(len(population), len(fitnesses))
  if pool_size == 0:
    raise ValueError("cannot select from an empty population")
  if tournament_size < 1:
    raise ValueError("tournament_size must be at least 1")

  # Sample indices rather than (individual, fitness) pairs: this avoids
  # building a zipped copy of the whole population on every call.
  contenders = random.sample(
    range(pool_size),
    min(tournament_size, pool_size)
  )

  # min() keeps the first of several equally fit contenders.
  winner = min(contenders, key=lambda index: fitnesses[index])

  return population[winner]


# ----- ----- ----- ----- ----- ----- ---- ----- ----- ----- ----- -----
# Crossover
# ----- ----- ----- ----- ----- ----- ---- ----- ----- ----- ----- -----
def crossover(parent1, parent2):
  """Create a child as a random weighted blend of two parents.

  A single weight is drawn from ``random.random()`` and applied to both
  coordinates, so the child lies on the segment between the parents.

  Args:
    parent1: First parent; its elements 0 and 1 are read as x and y.
    parent2: Second parent, same layout.

  Returns:
    A new two-element list ``[x, y]``; neither parent is modified.
  """
  weight = random.random()
  complement = 1 - weight

  return [
    weight * parent1[0] + complement * parent2[0],
    weight * parent1[1] + complement * parent2[1],
  ]


# ----- ----- ----- ----- ----- ----- ---- ----- ----- ----- ----- -----
# Mutation
# ----- ----- ----- ----- ----- ----- ---- ----- ----- ----- ----- -----
def mutate(individual, mutation_rate, xmin, xmax, ymin, ymax):
  """Randomly perturb an individual and keep it inside the search bounds.

  Each coordinate is, independently and with probability ``mutation_rate``,
  shifted by Gaussian noise whose standard deviation is 5 % of that axis'
  range. Afterwards both coordinates are clamped to their bounds.

  Args:
    individual: Mutable two-element sequence ``[x, y]``; modified in place.
    mutation_rate: Probability of perturbing each coordinate.
    xmin, xmax: Bounds of the x coordinate.
    ymin, ymax: Bounds of the y coordinate.

  Returns:
    The same ``individual`` object that was passed in.
  """
  bounds = ((xmin, xmax), (ymin, ymax))

  # Perturbation pass: x is handled before y.
  for axis, (low, high) in enumerate(bounds):
    if random.random() < mutation_rate:
      individual[axis] += random.gauss(0, (high - low) * 0.05)

  # Clamping pass: pull every coordinate back into [low, high].
  for axis, (low, high) in enumerate(bounds):
    individual[axis] = max(low, min(high, individual[axis]))

  return individual


# ----- ----- ----- ----- ----- ----- ---- ----- ----- ----- ----- -----
# Genetic Algorithm
# ----- ----- ----- ----- ----- ----- ---- ----- ----- ----- ----- -----
def genetic_algorithm(
    points,
    population_size=100,
    generations=300,
    mutation_rate=0.15,
    elite_size=5):
  """Search for the point with the smallest summed distance to ``points``.

  The search space is the bounding box of the input points, enlarged by a
  20 % margin per axis. Every generation is ranked by ``total_distance``,
  the best individuals are carried over unchanged (elitism) and the rest
  of the next generation is bred with ``selection``, ``crossover`` and
  ``mutate``. Progress is printed every 25 generations.

  Args:
    points: Array-like of shape ``(n, 2)`` with one ``(x, y)`` pair per row.
    population_size: Number of individuals per generation.
    generations: Number of breeding rounds. With ``0`` the best individual
      of the random initial population is returned.
    mutation_rate: Per-coordinate mutation probability handed to
      ``mutate``.
    elite_size: Number of top individuals copied into the next generation;
      clamped to the range ``[0, population_size]``.

  Returns:
    A tuple ``(best_solution, best_fitness)``: the best ``[x, y]`` found
    and its summed distance.

  Raises:
    ValueError: If ``points`` is empty or not of shape ``(n, 2)``, or if
      ``population_size`` is smaller than 1.
  """
  points = np.asarray(points, dtype=float)
  if points.ndim != 2 or points.shape[0] == 0 or points.shape[1] != 2:
    raise ValueError("points must be a non-empty array of shape (n, 2)")
  if population_size < 1:
    raise ValueError("population_size must be at least 1")

  # A negative elite_size would otherwise slice the ranking from the end.
  elite_count = max(0, min(elite_size, population_size))

  xmin = np.min(points[:, 0])
  xmax = np.max(points[:, 0])

  ymin = np.min(points[:, 1])
  ymax = np.max(points[:, 1])

  # Add margin
  margin_x = (xmax - xmin) * 0.2
  margin_y = (ymax - ymin) * 0.2

  xmin, xmax = xmin - margin_x, xmax + margin_x
  ymin, ymax = ymin - margin_y, ymax + margin_y

  population = create_population(
    population_size,
    xmin, xmax,
    ymin, ymax
  )

  best_solution = None
  best_fitness = float('inf')

  for generation in range(generations):

    fitnesses = [
      total_distance(ind, points)
      for ind in population
    ]

    ranked = sorted(
      zip(population, fitnesses),
      key=lambda pair: pair[1]
    )

    # Store best (as a copy, so the result never aliases a population member)
    if ranked[0][1] < best_fitness:
      best_solution = ranked[0][0].copy()
      best_fitness = ranked[0][1]

    # Elitism
    new_population = [
      individual.copy()
      for individual, _ in ranked[:elite_count]
    ]

    # Fill the remaining slots with mutated offspring.
    while len(new_population) < population_size:

      parent1 = selection(population, fitnesses)
      parent2 = selection(population, fitnesses)

      child = crossover(parent1, parent2)

      child = mutate(
        child,
        mutation_rate,
        xmin, xmax,
        ymin, ymax
      )

      new_population.append(child)

    population = new_population

    if generation % 25 == 0:
      print(
        f"Generation {generation:3d} "
        f"Best Distance = {best_fitness:.4f}"
      )

  # The population produced by the last breeding round has not been scored
  # yet; evaluate it so its offspring can still become the result.
  for individual in population:
    fitness = total_distance(individual, points)
    if fitness < best_fitness:
      best_solution = individual.copy()
      best_fitness = fitness

  return best_solution, best_fitness

def plot_results(points, best_point):
  """Show the input points, the optimised point and their connections.

  The figure contains the input points, the optimised point as a star, the
  dashed bounding box of the input points and one thin segment from the
  optimised point to every input point. The call ends with ``plt.show()``.

  Args:
    points: NumPy array with the x values in column 0 and the y values in
      column 1.
    best_point: Sequence whose elements 0 and 1 are the x and y of the
      optimised point.
  """
  xs = points[:, 0]
  ys = points[:, 1]

  left, right = np.min(xs), np.max(xs)
  bottom, top = np.min(ys), np.max(ys)

  star_x, star_y = best_point[0], best_point[1]

  _figure, ax = plt.subplots(figsize=(8, 6))

  # Input points
  ax.scatter(xs, ys, s=80, marker='o', label='Input Points')

  # Optimized point
  ax.scatter(star_x, star_y, s=250, marker='*', label='Optimal Point')

  # Bounding box of the input points
  bounding_box = Rectangle(
    (left, bottom),
    right - left,
    top - bottom,
    fill=False,
    linestyle='--',
    linewidth=2,
    label='Bounding Box'
  )
  ax.add_patch(bounding_box)

  # One segment from the optimised point to each input point
  for target_x, target_y in points:
    ax.plot([star_x, target_x], [star_y, target_y], linewidth=0.5)

  ax.set_title('Geometric Median via Genetic Algorithm')
  ax.set_xlabel('X')
  ax.set_ylabel('Y')

  ax.grid(True)
  ax.axis('equal')
  ax.legend()

  plt.tight_layout()
  plt.show()
  
# ----- ----- ----- ----- ----- ----- ---- ----- ----- ----- ----- -----
# Main
# ----- ----- ----- ----- ----- ----- ---- ----- ----- ----- ----- -----
if __name__ == "__main__":

  # Read the input coordinates named in the user settings.
  points = load_points(csv_file)

  # Forward the user settings to the optimiser as keyword arguments.
  ga_settings = {
    "population_size": n_population,
    "generations": n_generations,
    "mutation_rate": rate_mutation,
    "elite_size": n_elite,
  }
  best_point, best_distance = genetic_algorithm(points, **ga_settings)

  # Report the result, every value with four decimals.
  print("\nOptimal Point")
  print("-------------------")
  for label, value in (
      ("X", best_point[0]),
      ("Y", best_point[1]),
      ("Total Distance", best_distance)):
    print(f"{label} = {value:.4f}")

  plot_results(points, best_point)
  
  
