From 1e433c123fee823fc2bda8194929adf85e04fbd9 Mon Sep 17 00:00:00 2001 From: coolneng Date: Wed, 16 Jun 2021 13:02:49 +0200 Subject: [PATCH] Remove base counts from the dataset --- src/model.py | 1 - src/preprocessing.py | 43 +++++++++++-------------------------------- 2 files changed, 11 insertions(+), 33 deletions(-) diff --git a/src/model.py b/src/model.py index 3252c1e..8ef83ec 100644 --- a/src/model.py +++ b/src/model.py @@ -86,7 +86,6 @@ def run(data_file, label_file, seed_value=42) -> None: epochs=EPOCHS, validation_data=eval_data, callbacks=[tensorboard], - verbose=0, ) print("Training complete. Obtaining final metrics...") show_metrics(model, eval_data, test_data) diff --git a/src/preprocessing.py b/src/preprocessing.py index ca72b67..df705a6 100644 --- a/src/preprocessing.py +++ b/src/preprocessing.py @@ -1,31 +1,21 @@ -from typing import List, Tuple +from typing import Dict, List, Tuple -from Bio.motifs import create from Bio.SeqIO import parse from numpy.random import random -from tensorflow import Tensor, int64, stack, cast, int32 -from tensorflow.sparse import to_dense +from tensorflow import Tensor, int64 from tensorflow.data import TFRecordDataset -from tensorflow.io import ( - FixedLenFeature, - TFRecordWriter, - VarLenFeature, - parse_single_example, -) +from tensorflow.io import TFRecordWriter, VarLenFeature, parse_single_example +from tensorflow.sparse import to_dense from tensorflow.train import Example, Feature, Features, Int64List from constants import * -def generate_example(sequence, label, base_counts) -> bytes: +def generate_example(sequence, label) -> bytes: """ Create a binary-string for each sequence containing the sequence and the bases' counts """ schema = { - "A_counts": Feature(int64_list=Int64List(value=[sum(base_counts["A"])])), - "C_counts": Feature(int64_list=Int64List(value=[sum(base_counts["C"])])), - "G_counts": Feature(int64_list=Int64List(value=[sum(base_counts["G"])])), - "T_counts": Feature(int64_list=Int64List(value=[sum(base_counts["T"])])), "sequence": Feature(int64_list=Int64List(value=encode_sequence(sequence))), "label": Feature(int64_list=Int64List(value=encode_sequence(label))), } @@ -48,11 +38,9 @@ def read_fastq(data_file, label_file) -> List[bytes]: examples = [] with open(data_file) as data, open(label_file) as labels: for element, label in zip(parse(data, "fastq"), parse(labels, "fastq")): - motifs = create([element.seq]) example = generate_example( sequence=str(element.seq), label=str(label.seq), - base_counts=motifs.counts, ) examples.append(example) return examples @@ -77,19 +65,14 @@ def create_dataset( test.write(element) -def transform_features(parsed_features) -> List[Tensor]: +def transform_features(parsed_features) -> Dict[str, Tensor]: """ - Cast and transform the parsed features of an Example into a list of Tensors + Transform the parsed features of an Example into a list of dense Tensors """ + features = {} sparse_features = ["sequence", "label"] - for feature in sparse_features: - parsed_features[feature] = cast(parsed_features[feature], int32) - parsed_features[feature] = to_dense(parsed_features[feature]) - for base in BASES: - parsed_features[f"{base}_counts"] = cast( - parsed_features[f"{base}_counts"], int32 - ) - features = list(parsed_features.values())[:-1] + for element in sparse_features: + features[element] = to_dense(parsed_features[element]) return features @@ -98,16 +81,12 @@ def process_input(byte_string) -> Tuple[Tensor, Tensor]: Parse a byte-string into an Example object """ schema = { - "A_counts": FixedLenFeature(shape=[1], dtype=int64), - "C_counts": FixedLenFeature(shape=[1], dtype=int64), - "G_counts": FixedLenFeature(shape=[1], dtype=int64), - "T_counts": FixedLenFeature(shape=[1], dtype=int64), "sequence": VarLenFeature(dtype=int64), "label": VarLenFeature(dtype=int64), } parsed_features = parse_single_example(byte_string, features=schema) features = transform_features(parsed_features) - return stack(features, axis=-1), parsed_features["label"] + return features["sequence"], features["label"] def read_dataset(filepath) -> TFRecordDataset: