Remove base counts from the dataset

This commit is contained in:
coolneng 2021-06-16 13:02:49 +02:00
parent a2ae7bbe11
commit 1e433c123f
Signed by: coolneng
GPG Key ID: 9893DA236405AF57
2 changed files with 11 additions and 33 deletions

View File

@ -86,7 +86,6 @@ def run(data_file, label_file, seed_value=42) -> None:
epochs=EPOCHS, epochs=EPOCHS,
validation_data=eval_data, validation_data=eval_data,
callbacks=[tensorboard], callbacks=[tensorboard],
verbose=0,
) )
print("Training complete. Obtaining final metrics...") print("Training complete. Obtaining final metrics...")
show_metrics(model, eval_data, test_data) show_metrics(model, eval_data, test_data)

View File

@ -1,31 +1,21 @@
from typing import List, Tuple from typing import Dict, List, Tuple
from Bio.motifs import create
from Bio.SeqIO import parse from Bio.SeqIO import parse
from numpy.random import random from numpy.random import random
from tensorflow import Tensor, int64, stack, cast, int32 from tensorflow import Tensor, int64
from tensorflow.sparse import to_dense
from tensorflow.data import TFRecordDataset from tensorflow.data import TFRecordDataset
from tensorflow.io import ( from tensorflow.io import TFRecordWriter, VarLenFeature, parse_single_example
FixedLenFeature, from tensorflow.sparse import to_dense
TFRecordWriter,
VarLenFeature,
parse_single_example,
)
from tensorflow.train import Example, Feature, Features, Int64List from tensorflow.train import Example, Feature, Features, Int64List
from constants import * from constants import *
def generate_example(sequence, label, base_counts) -> bytes: def generate_example(sequence, label) -> bytes:
""" """
Create a binary-string for each sequence containing the sequence and the bases' counts Create a binary-string for each sequence containing the sequence and its label
""" """
schema = { schema = {
"A_counts": Feature(int64_list=Int64List(value=[sum(base_counts["A"])])),
"C_counts": Feature(int64_list=Int64List(value=[sum(base_counts["C"])])),
"G_counts": Feature(int64_list=Int64List(value=[sum(base_counts["G"])])),
"T_counts": Feature(int64_list=Int64List(value=[sum(base_counts["T"])])),
"sequence": Feature(int64_list=Int64List(value=encode_sequence(sequence))), "sequence": Feature(int64_list=Int64List(value=encode_sequence(sequence))),
"label": Feature(int64_list=Int64List(value=encode_sequence(label))), "label": Feature(int64_list=Int64List(value=encode_sequence(label))),
} }
@ -48,11 +38,9 @@ def read_fastq(data_file, label_file) -> List[bytes]:
examples = [] examples = []
with open(data_file) as data, open(label_file) as labels: with open(data_file) as data, open(label_file) as labels:
for element, label in zip(parse(data, "fastq"), parse(labels, "fastq")): for element, label in zip(parse(data, "fastq"), parse(labels, "fastq")):
motifs = create([element.seq])
example = generate_example( example = generate_example(
sequence=str(element.seq), sequence=str(element.seq),
label=str(label.seq), label=str(label.seq),
base_counts=motifs.counts,
) )
examples.append(example) examples.append(example)
return examples return examples
@ -77,19 +65,14 @@ def create_dataset(
test.write(element) test.write(element)
def transform_features(parsed_features) -> List[Tensor]: def transform_features(parsed_features) -> Dict[str, Tensor]:
""" """
Cast and transform the parsed features of an Example into a list of Tensors Transform the parsed features of an Example into a dictionary of dense Tensors
""" """
features = {}
sparse_features = ["sequence", "label"] sparse_features = ["sequence", "label"]
for feature in sparse_features: for element in sparse_features:
parsed_features[feature] = cast(parsed_features[feature], int32) features[element] = to_dense(parsed_features[element])
parsed_features[feature] = to_dense(parsed_features[feature])
for base in BASES:
parsed_features[f"{base}_counts"] = cast(
parsed_features[f"{base}_counts"], int32
)
features = list(parsed_features.values())[:-1]
return features return features
@ -98,16 +81,12 @@ def process_input(byte_string) -> Tuple[Tensor, Tensor]:
Parse a byte-string into an Example object Parse a byte-string into an Example object
""" """
schema = { schema = {
"A_counts": FixedLenFeature(shape=[1], dtype=int64),
"C_counts": FixedLenFeature(shape=[1], dtype=int64),
"G_counts": FixedLenFeature(shape=[1], dtype=int64),
"T_counts": FixedLenFeature(shape=[1], dtype=int64),
"sequence": VarLenFeature(dtype=int64), "sequence": VarLenFeature(dtype=int64),
"label": VarLenFeature(dtype=int64), "label": VarLenFeature(dtype=int64),
} }
parsed_features = parse_single_example(byte_string, features=schema) parsed_features = parse_single_example(byte_string, features=schema)
features = transform_features(parsed_features) features = transform_features(parsed_features)
return stack(features, axis=-1), parsed_features["label"] return features["sequence"], features["label"]
def read_dataset(filepath) -> TFRecordDataset: def read_dataset(filepath) -> TFRecordDataset: