diff --git a/src/preprocessing.py b/src/preprocessing.py
index 1ab533e..98fb80a 100644
--- a/src/preprocessing.py
+++ b/src/preprocessing.py
@@ -36,7 +36,7 @@ def encode_sequence(sequence) -> List[int]:
     return encoded_sequence
 
 
-def parse_data(filepath) -> List[bytes]:
+def read_fastq(filepath) -> List[bytes]:
     """
     Parse a FASTQ file and generate a List of serialized Examples
     """
@@ -54,7 +54,7 @@ def create_dataset(filepath) -> None:
     """
     Create a training and test dataset with a 70/30 split respectively
     """
-    data = parse_data(filepath)
+    data = read_fastq(filepath)
     train_test_split = 0.7
     with TFRecordWriter(TRAIN_DATASET) as train, TFRecordWriter(TEST_DATASET) as test:
         for element in data:
@@ -64,7 +64,10 @@ def create_dataset(filepath) -> None:
                 test.write(element)
 
 
-def process_input(byte_string):
+def process_input(byte_string) -> Example:
+    """
+    Parse a byte-string into an Example object
+    """
     schema = {
         "sequence": FixedLenFeature(shape=[], dtype=int64),
         "A_counts": FixedLenFeature(shape=[], dtype=float32),
@@ -75,13 +78,12 @@ def process_input(byte_string):
     return parse_single_example(byte_string, features=schema)
 
 
-def read_dataset():
+def read_dataset() -> TFRecordDataset:
+    """
+    Read TFRecords files and generate a dataset
+    """
     data_input = TFRecordDataset(filenames=[TRAIN_DATASET, TEST_DATASET])
     dataset = data_input.map(map_func=process_input)
     shuffled_dataset = dataset.shuffle(buffer_size=10000, reshuffle_each_iteration=True)
     batched_dataset = shuffled_dataset.batch(batch_size=BATCH_SIZE).repeat(count=EPOCHS)
     return batched_dataset
-
-
-create_dataset("data/curesim-HVR.fastq")
-dataset = read_dataset()