diff --git a/src/local_search.py b/src/local_search.py index 8377654..2b5d7fa 100644 --- a/src/local_search.py +++ b/src/local_search.py @@ -3,14 +3,17 @@ from numpy.random import choice, seed def get_first_random_solution(m, data): seed(42) - random_indexes = choice(len(data.index), size=m) - return data.iloc[random_indexes] + random_indexes = choice(len(data.index), size=m, replace=False) + return data.loc[random_indexes] def replace_worst_element(previous, data): solution = previous.copy() worst_index = solution["distance"].astype(float).idxmin() - solution.loc[worst_index] = data.sample(random_state=42).squeeze() + random_element = data.sample().squeeze() + while solution.isin(random_element.values.ravel()).any().any(): + random_element = data.sample().squeeze() + solution.loc[worst_index] = random_element return solution, worst_index @@ -28,11 +31,6 @@ def get_random_solution(previous, data): return solution -def remove_duplicates(element, data): - duplicate_free_df = data.query( - "(source not in @element.source) or (destination not in @element.destination)" - ) - return duplicate_free_df def explore_neighbourhood(element, data, max_iterations=100000): @@ -41,7 +39,6 @@ def explore_neighbourhood(element, data, max_iterations=100000): for _ in range(max_iterations): previous_solution = neighbourhood[-1] neighbour = get_random_solution(previous=previous_solution, data=data) - data = remove_duplicates(element=neighbour, data=data) if neighbour.equals(previous_solution): break neighbourhood.append(neighbour)