diff --git a/src/local_search.py b/src/local_search.py index 8cff412..3567ba6 100644 --- a/src/local_search.py +++ b/src/local_search.py @@ -7,11 +7,18 @@ def get_first_random_solution(m, data): return data.loc[random_indexes] +def element_in_dataframe(solution, element): + duplicates = solution.query( + f"(source == {element.source} and destination == {element.destination}) or (source == {element.destination} and destination == {element.source})" + ) + return not duplicates.empty + + def replace_worst_element(previous, data): solution = previous.copy() worst_index = solution["distance"].astype(float).idxmin() random_element = data.sample().squeeze() - while solution.isin(random_element.values.ravel()).any().any(): + while element_in_dataframe(solution=solution, element=random_element): random_element = data.sample().squeeze() solution.loc[worst_index] = random_element return solution, worst_index