diff --git a/src/genetic_algorithm.py b/src/genetic_algorithm.py index 7561fb5..91d9136 100644 --- a/src/genetic_algorithm.py +++ b/src/genetic_algorithm.py @@ -1,4 +1,4 @@ -from numpy import sum, append, intersect1d +from numpy import sum, append, intersect1d, array_equal from numpy.random import randint, choice, shuffle from pandas import DataFrame from math import ceil @@ -133,15 +133,28 @@ def position_crossover(parents): return first_offspring, second_offspring +def group_parents(parents): + parent_pairs = [] + for i in range(0, len(parents), 2): + first = parents[i] + second = parents[i + 1] + if array_equal(first.point.values, second.point.values): + tmp = second + second = parents[i - 2] + parents[i - 2] = tmp + parent_pairs.append([first, second]) + return parent_pairs + + def crossover(mode, parents, m): - split_parents = list(zip(*[iter(parents)] * 2)) + parent_groups = group_parents(parents) offspring = [] if mode == "uniform": - for element in split_parents: + for element in parent_groups: offspring.append(uniform_crossover(element, m)) offspring.append(uniform_crossover(element, m)) else: - for element in split_parents: + for element in parent_groups: first_offspring, second_offspring = position_crossover(element) offspring.append(first_offspring) offspring.append(second_offspring) @@ -194,7 +207,7 @@ def tournament_selection(population): def check_element_population(element, population): for item in population: - if all(element.point.values) == all(item.point.values): + if array_equal(element.point.values, item.point.values): return True return False