Refactor parent grouping making it more resilient
This commit is contained in:
parent
48737fd6f0
commit
35ca73ba74
|
@ -1,4 +1,4 @@
|
||||||
from numpy import sum, append, intersect1d
|
from numpy import sum, append, intersect1d, array_equal
|
||||||
from numpy.random import randint, choice, shuffle
|
from numpy.random import randint, choice, shuffle
|
||||||
from pandas import DataFrame
|
from pandas import DataFrame
|
||||||
from math import ceil
|
from math import ceil
|
||||||
|
@ -133,15 +133,28 @@ def position_crossover(parents):
|
||||||
return first_offspring, second_offspring
|
return first_offspring, second_offspring
|
||||||
|
|
||||||
|
|
||||||
|
def group_parents(parents):
|
||||||
|
parent_pairs = []
|
||||||
|
for i in range(0, len(parents), 2):
|
||||||
|
first = parents[i]
|
||||||
|
second = parents[i + 1]
|
||||||
|
if array_equal(first.point.values, second.point.values):
|
||||||
|
tmp = second
|
||||||
|
second = parents[i - 2]
|
||||||
|
parents[i - 2] = tmp
|
||||||
|
parent_pairs.append([first, second])
|
||||||
|
return parent_pairs
|
||||||
|
|
||||||
|
|
||||||
def crossover(mode, parents, m):
|
def crossover(mode, parents, m):
|
||||||
split_parents = list(zip(*[iter(parents)] * 2))
|
parent_groups = group_parents(parents)
|
||||||
offspring = []
|
offspring = []
|
||||||
if mode == "uniform":
|
if mode == "uniform":
|
||||||
for element in split_parents:
|
for element in parent_groups:
|
||||||
offspring.append(uniform_crossover(element, m))
|
offspring.append(uniform_crossover(element, m))
|
||||||
offspring.append(uniform_crossover(element, m))
|
offspring.append(uniform_crossover(element, m))
|
||||||
else:
|
else:
|
||||||
for element in split_parents:
|
for element in parent_groups:
|
||||||
first_offspring, second_offspring = position_crossover(element)
|
first_offspring, second_offspring = position_crossover(element)
|
||||||
offspring.append(first_offspring)
|
offspring.append(first_offspring)
|
||||||
offspring.append(second_offspring)
|
offspring.append(second_offspring)
|
||||||
|
@ -194,7 +207,7 @@ def tournament_selection(population):
|
||||||
|
|
||||||
def check_element_population(element, population):
|
def check_element_population(element, population):
|
||||||
for item in population:
|
for item in population:
|
||||||
if all(element.point.values) == all(item.point.values):
|
if array_equal(element.point.values, item.point.values):
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue