Generate a dataset from the TFRecords files

coolneng 2021-06-01 23:06:25 +02:00
parent 220c0482f1
commit d34e291085
Signed by: coolneng
GPG Key ID: 9893DA236405AF57
2 changed files with 24 additions and 4 deletions


@@ -1,3 +1,5 @@
 BASES = "ACGT"
 TRAIN_DATASET = "data/train_data.tfrecords"
 TEST_DATASET = "data/test_data.tfrecords"
+EPOCHS = 1000
+BATCH_SIZE = 256
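
The two new constants feed the batching step introduced in the second file below: parsed records are shuffled, grouped into batches of 256, and the whole dataset is replayed 1000 times. As a rough illustrative sketch (the record count N below is a made-up placeholder, not a value from this commit), the resulting step counts work out as:

    from math import ceil

    N = 10_000                      # hypothetical record count, for illustration only
    BATCH_SIZE, EPOCHS = 256, 1000  # values added in this commit
    steps_per_epoch = ceil(N / BATCH_SIZE)  # batch() keeps the final partial batch, hence ceil
    total_steps = steps_per_epoch * EPOCHS  # repeat(count=EPOCHS) replays the shuffled data
    print(steps_per_epoch, total_steps)     # 40 40000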


@@ -1,12 +1,14 @@
-from typing import List
 from Bio.motifs import create
 from Bio.SeqIO import parse
 from numpy.random import random
-from tensorflow.io import TFRecordWriter
+from tensorflow import float32, string
 from tensorflow.data import TFRecordDataset
+from tensorflow.io import FixedLenFeature, TFRecordWriter, parse_single_example
 from tensorflow.train import BytesList, Example, Feature, Features, FloatList
+from typing import List
 
-from constants import TRAIN_DATASET, TEST_DATASET
+from constants import BATCH_SIZE, EPOCHS, TEST_DATASET, TRAIN_DATASET
 
 
 def generate_example(sequence, weight_matrix) -> bytes:
@@ -52,8 +54,24 @@ def create_dataset(filepath) -> None:
         test.write(element)
 
 
+def process_input(byte_string):
+    schema = {
+        "sequence": FixedLenFeature(shape=[], dtype=string),
+        "A_counts": FixedLenFeature(shape=[], dtype=float32),
+        "C_counts": FixedLenFeature(shape=[], dtype=float32),
+        "G_counts": FixedLenFeature(shape=[], dtype=float32),
+        "T_counts": FixedLenFeature(shape=[], dtype=float32),
+    }
+    return parse_single_example(byte_string, features=schema)
+
+
 def read_dataset():
-    pass
+    data_input = TFRecordDataset(filenames=[TRAIN_DATASET, TEST_DATASET])
+    dataset = data_input.map(map_func=process_input)
+    shuffled_dataset = dataset.shuffle(buffer_size=10000, reshuffle_each_iteration=True)
+    batched_dataset = shuffled_dataset.batch(batch_size=BATCH_SIZE).repeat(count=EPOCHS)
+    return batched_dataset
 
 
 create_dataset("data/curesim-HVR.fastq")
+dataset = read_dataset()
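
For reference, a minimal usage sketch of the pipeline this commit adds (assuming TensorFlow 2.x eager execution and that both .tfrecords files already exist on disk): each element that read_dataset() yields is a dict of batched tensors keyed by the schema defined in process_input.

    # Illustrative only: pull a single batch from the pipeline and inspect it.
    for batch in dataset.take(1):
        print(batch["sequence"].shape)  # up to (BATCH_SIZE,) tf.string sequences
        print(batch["A_counts"][:5])    # float32 counts for base A, first 5 records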