Refactor parent grouping making it more resilient

This commit is contained in:
coolneng 2021-06-21 02:08:15 +02:00
parent 48737fd6f0
commit 35ca73ba74
Signed by: coolneng
GPG Key ID: 9893DA236405AF57
1 changed files with 18 additions and 5 deletions

View File

@ -1,4 +1,4 @@
from numpy import sum, append, intersect1d from numpy import sum, append, intersect1d, array_equal
from numpy.random import randint, choice, shuffle from numpy.random import randint, choice, shuffle
from pandas import DataFrame from pandas import DataFrame
from math import ceil from math import ceil
@ -133,15 +133,28 @@ def position_crossover(parents):
return first_offspring, second_offspring return first_offspring, second_offspring
def group_parents(parents):
parent_pairs = []
for i in range(0, len(parents), 2):
first = parents[i]
second = parents[i + 1]
if array_equal(first.point.values, second.point.values):
tmp = second
second = parents[i - 2]
parents[i - 2] = tmp
parent_pairs.append([first, second])
return parent_pairs
def crossover(mode, parents, m): def crossover(mode, parents, m):
split_parents = list(zip(*[iter(parents)] * 2)) parent_groups = group_parents(parents)
offspring = [] offspring = []
if mode == "uniform": if mode == "uniform":
for element in split_parents: for element in parent_groups:
offspring.append(uniform_crossover(element, m)) offspring.append(uniform_crossover(element, m))
offspring.append(uniform_crossover(element, m)) offspring.append(uniform_crossover(element, m))
else: else:
for element in split_parents: for element in parent_groups:
first_offspring, second_offspring = position_crossover(element) first_offspring, second_offspring = position_crossover(element)
offspring.append(first_offspring) offspring.append(first_offspring)
offspring.append(second_offspring) offspring.append(second_offspring)
@ -194,7 +207,7 @@ def tournament_selection(population):
def check_element_population(element, population): def check_element_population(element, population):
for item in population: for item in population:
if all(element.point.values) == all(item.point.values): if array_equal(element.point.values, item.point.values):
return True return True
return False return False