From 220c0482f16c7d0e2c600b333fbadb8ae2dad761 Mon Sep 17 00:00:00 2001 From: coolneng Date: Tue, 1 Jun 2021 19:27:10 +0200 Subject: [PATCH] Move hardcorded data to a constants module --- src/constants.py | 3 +++ src/model.py | 6 ++++-- src/preprocessing.py | 15 ++++++++++----- 3 files changed, 17 insertions(+), 7 deletions(-) create mode 100644 src/constants.py diff --git a/src/constants.py b/src/constants.py new file mode 100644 index 0000000..c5d64bc --- /dev/null +++ b/src/constants.py @@ -0,0 +1,3 @@ +BASES = "ACGT" +TRAIN_DATASET = "data/train_data.tfrecords" +TEST_DATASET = "data/test_data.tfrecords" diff --git a/src/model.py b/src/model.py index dbf3bbe..6c43f2c 100644 --- a/src/model.py +++ b/src/model.py @@ -1,8 +1,10 @@ from tensorflow.keras import Model, Sequential, layers from tensorflow.keras.regularizers import l2 +from constants import BASES -def build_model(hyper_parameters, bases="ACGT") -> Model: + +def build_model(hyper_parameters) -> Model: """ Builds the CNN model """ @@ -39,6 +41,6 @@ def build_model(hyper_parameters, bases="ACGT") -> Model: ), layers.Dropout(rate=0.3), # Output layer with softmax activation - layers.Dense(units=len(bases), activation="softmax"), + layers.Dense(units=len(BASES), activation="softmax"), ] ) diff --git a/src/preprocessing.py b/src/preprocessing.py index 9904d73..68aad09 100644 --- a/src/preprocessing.py +++ b/src/preprocessing.py @@ -2,9 +2,12 @@ from Bio.motifs import create from Bio.SeqIO import parse from numpy.random import random from tensorflow.io import TFRecordWriter +from tensorflow.data import TFRecordDataset from tensorflow.train import BytesList, Example, Feature, Features, FloatList from typing import List +from constants import TRAIN_DATASET, TEST_DATASET + def generate_example(sequence, weight_matrix) -> bytes: """ @@ -41,14 +44,16 @@ def create_dataset(filepath) -> None: """ data = parse_data(filepath) train_test_split = 0.7 - with TFRecordWriter("data/train_data.tfrecords") as train_writer, TFRecordWriter( - "data/test_data.tfrecords" - ) as test_writer: + with TFRecordWriter(TRAIN_DATASET) as train, TFRecordWriter(TEST_DATASET) as test: for element in data: if random() < train_test_split: - train_writer.write(element) + train.write(element) else: - test_writer.write(element) + test.write(element) + + +def read_dataset(): + pass create_dataset("data/curesim-HVR.fastq")