Move hardcorded data to a constants module

This commit is contained in:
coolneng 2021-06-01 19:27:10 +02:00
parent 44ff69dc9e
commit 220c0482f1
Signed by: coolneng
GPG Key ID: 9893DA236405AF57
3 changed files with 17 additions and 7 deletions

3
src/constants.py Normal file
View File

@ -0,0 +1,3 @@
BASES = "ACGT"
TRAIN_DATASET = "data/train_data.tfrecords"
TEST_DATASET = "data/test_data.tfrecords"

View File

@ -1,8 +1,10 @@
from tensorflow.keras import Model, Sequential, layers from tensorflow.keras import Model, Sequential, layers
from tensorflow.keras.regularizers import l2 from tensorflow.keras.regularizers import l2
from constants import BASES
def build_model(hyper_parameters, bases="ACGT") -> Model:
def build_model(hyper_parameters) -> Model:
""" """
Builds the CNN model Builds the CNN model
""" """
@ -39,6 +41,6 @@ def build_model(hyper_parameters, bases="ACGT") -> Model:
), ),
layers.Dropout(rate=0.3), layers.Dropout(rate=0.3),
# Output layer with softmax activation # Output layer with softmax activation
layers.Dense(units=len(bases), activation="softmax"), layers.Dense(units=len(BASES), activation="softmax"),
] ]
) )

View File

@ -2,9 +2,12 @@ from Bio.motifs import create
from Bio.SeqIO import parse from Bio.SeqIO import parse
from numpy.random import random from numpy.random import random
from tensorflow.io import TFRecordWriter from tensorflow.io import TFRecordWriter
from tensorflow.data import TFRecordDataset
from tensorflow.train import BytesList, Example, Feature, Features, FloatList from tensorflow.train import BytesList, Example, Feature, Features, FloatList
from typing import List from typing import List
from constants import TRAIN_DATASET, TEST_DATASET
def generate_example(sequence, weight_matrix) -> bytes: def generate_example(sequence, weight_matrix) -> bytes:
""" """
@ -41,14 +44,16 @@ def create_dataset(filepath) -> None:
""" """
data = parse_data(filepath) data = parse_data(filepath)
train_test_split = 0.7 train_test_split = 0.7
with TFRecordWriter("data/train_data.tfrecords") as train_writer, TFRecordWriter( with TFRecordWriter(TRAIN_DATASET) as train, TFRecordWriter(TEST_DATASET) as test:
"data/test_data.tfrecords"
) as test_writer:
for element in data: for element in data:
if random() < train_test_split: if random() < train_test_split:
train_writer.write(element) train.write(element)
else: else:
test_writer.write(element) test.write(element)
def read_dataset():
pass
create_dataset("data/curesim-HVR.fastq") create_dataset("data/curesim-HVR.fastq")