From d34e291085b42a3dbdc6be2c89654d12a040be72 Mon Sep 17 00:00:00 2001 From: coolneng Date: Tue, 1 Jun 2021 23:06:25 +0200 Subject: [PATCH] Generate a dataset from the TFRecords files --- src/constants.py | 2 ++ src/preprocessing.py | 26 ++++++++++++++++++++++---- 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/src/constants.py b/src/constants.py index c5d64bc..a9be1d0 100644 --- a/src/constants.py +++ b/src/constants.py @@ -1,3 +1,5 @@ BASES = "ACGT" TRAIN_DATASET = "data/train_data.tfrecords" TEST_DATASET = "data/test_data.tfrecords" +EPOCHS = 1000 +BATCH_SIZE = 256 diff --git a/src/preprocessing.py b/src/preprocessing.py index 68aad09..34cb1b6 100644 --- a/src/preprocessing.py +++ b/src/preprocessing.py @@ -1,12 +1,14 @@ +from typing import List + from Bio.motifs import create from Bio.SeqIO import parse from numpy.random import random -from tensorflow.io import TFRecordWriter +from tensorflow import float32, string from tensorflow.data import TFRecordDataset +from tensorflow.io import FixedLenFeature, TFRecordWriter, parse_single_example from tensorflow.train import BytesList, Example, Feature, Features, FloatList -from typing import List -from constants import TRAIN_DATASET, TEST_DATASET +from constants import BATCH_SIZE, EPOCHS, TEST_DATASET, TRAIN_DATASET def generate_example(sequence, weight_matrix) -> bytes: @@ -52,8 +54,24 @@ def create_dataset(filepath) -> None: test.write(element) +def process_input(byte_string): + schema = { + "sequence": FixedLenFeature(shape=[], dtype=string), + "A_counts": FixedLenFeature(shape=[], dtype=float32), + "C_counts": FixedLenFeature(shape=[], dtype=float32), + "G_counts": FixedLenFeature(shape=[], dtype=float32), + "T_counts": FixedLenFeature(shape=[], dtype=float32), + } + return parse_single_example(byte_string, features=schema) + + def read_dataset(): - pass + data_input = TFRecordDataset(filenames=[TRAIN_DATASET, TEST_DATASET]) + dataset = data_input.map(map_func=process_input) + shuffled_dataset = dataset.shuffle(buffer_size=10000, reshuffle_each_iteration=True) + batched_dataset = shuffled_dataset.batch(batch_size=BATCH_SIZE).repeat(count=EPOCHS) + return batched_dataset create_dataset("data/curesim-HVR.fastq") +dataset = read_dataset()