diff --git a/src/model.py b/src/model.py index de6e87c..dbf3bbe 100644 --- a/src/model.py +++ b/src/model.py @@ -1,10 +1,8 @@ from tensorflow.keras import Model, Sequential, layers from tensorflow.keras.regularizers import l2 -from preprocessing import BASES - -def build_model(hyper_parameters) -> Model: +def build_model(hyper_parameters, bases="ACGT") -> Model: """ Builds the CNN model """ @@ -41,6 +39,6 @@ def build_model(hyper_parameters) -> Model: ), layers.Dropout(rate=0.3), # Output layer with softmax activation - layers.Dense(units=len(BASES), activation="softmax"), + layers.Dense(units=len(bases), activation="softmax"), ] ) diff --git a/src/preprocessing.py b/src/preprocessing.py index f04b6d6..93fa5b6 100644 --- a/src/preprocessing.py +++ b/src/preprocessing.py @@ -4,8 +4,6 @@ from numpy.random import random from tensorflow.io import TFRecordWriter from tensorflow.train import BytesList, Example, Feature, Features, FloatList -BASES = "ACGT" - def generate_example(sequence, weight_matrix): schema = {