Compare commits

...

2 Commits

2 changed files with 12 additions and 8 deletions

View File

@ -3,3 +3,5 @@ TRAIN_DATASET = "data/train_data.tfrecords"
TEST_DATASET = "data/test_data.tfrecords" TEST_DATASET = "data/test_data.tfrecords"
EPOCHS = 1000 EPOCHS = 1000
BATCH_SIZE = 256 BATCH_SIZE = 256
LEARNING_RATE = 0.004
L2 = 0.001

View File

@ -36,7 +36,7 @@ def encode_sequence(sequence) -> List[int]:
return encoded_sequence return encoded_sequence
def parse_data(filepath) -> List[bytes]: def read_fastq(filepath) -> List[bytes]:
""" """
Parse a FASTQ file and generate a List of serialized Examples Parse a FASTQ file and generate a List of serialized Examples
""" """
@ -54,7 +54,7 @@ def create_dataset(filepath) -> None:
""" """
Create a training and test dataset with a 70/30 split respectively Create a training and test dataset with a 70/30 split respectively
""" """
data = parse_data(filepath) data = read_fastq(filepath)
train_test_split = 0.7 train_test_split = 0.7
with TFRecordWriter(TRAIN_DATASET) as train, TFRecordWriter(TEST_DATASET) as test: with TFRecordWriter(TRAIN_DATASET) as train, TFRecordWriter(TEST_DATASET) as test:
for element in data: for element in data:
@ -64,7 +64,10 @@ def create_dataset(filepath) -> None:
test.write(element) test.write(element)
def process_input(byte_string): def process_input(byte_string) -> Example:
"""
Parse a byte-string into an Example object
"""
schema = { schema = {
"sequence": FixedLenFeature(shape=[], dtype=int64), "sequence": FixedLenFeature(shape=[], dtype=int64),
"A_counts": FixedLenFeature(shape=[], dtype=float32), "A_counts": FixedLenFeature(shape=[], dtype=float32),
@ -75,13 +78,12 @@ def process_input(byte_string):
return parse_single_example(byte_string, features=schema) return parse_single_example(byte_string, features=schema)
def read_dataset(): def read_dataset() -> TFRecordDataset:
"""
Read TFRecords files and generate a dataset
"""
data_input = TFRecordDataset(filenames=[TRAIN_DATASET, TEST_DATASET]) data_input = TFRecordDataset(filenames=[TRAIN_DATASET, TEST_DATASET])
dataset = data_input.map(map_func=process_input) dataset = data_input.map(map_func=process_input)
shuffled_dataset = dataset.shuffle(buffer_size=10000, reshuffle_each_iteration=True) shuffled_dataset = dataset.shuffle(buffer_size=10000, reshuffle_each_iteration=True)
batched_dataset = shuffled_dataset.batch(batch_size=BATCH_SIZE).repeat(count=EPOCHS) batched_dataset = shuffled_dataset.batch(batch_size=BATCH_SIZE).repeat(count=EPOCHS)
return batched_dataset return batched_dataset
create_dataset("data/curesim-HVR.fastq")
dataset = read_dataset()