Remove base counts from the dataset

This commit is contained in:
coolneng 2021-06-16 13:02:49 +02:00
parent a2ae7bbe11
commit 1e433c123f
Signed by: coolneng
GPG Key ID: 9893DA236405AF57
2 changed files with 11 additions and 33 deletions

View File

@ -86,7 +86,6 @@ def run(data_file, label_file, seed_value=42) -> None:
epochs=EPOCHS,
validation_data=eval_data,
callbacks=[tensorboard],
verbose=0,
)
print("Training complete. Obtaining final metrics...")
show_metrics(model, eval_data, test_data)

View File

@ -1,31 +1,21 @@
from typing import List, Tuple
from typing import Dict, List, Tuple
from Bio.motifs import create
from Bio.SeqIO import parse
from numpy.random import random
from tensorflow import Tensor, int64, stack, cast, int32
from tensorflow.sparse import to_dense
from tensorflow import Tensor, int64
from tensorflow.data import TFRecordDataset
from tensorflow.io import (
FixedLenFeature,
TFRecordWriter,
VarLenFeature,
parse_single_example,
)
from tensorflow.io import TFRecordWriter, VarLenFeature, parse_single_example
from tensorflow.sparse import to_dense
from tensorflow.train import Example, Feature, Features, Int64List
from constants import *
def generate_example(sequence, label, base_counts) -> bytes:
def generate_example(sequence, label) -> bytes:
"""
Create a binary-string for each sequence containing the sequence and its label
"""
schema = {
"A_counts": Feature(int64_list=Int64List(value=[sum(base_counts["A"])])),
"C_counts": Feature(int64_list=Int64List(value=[sum(base_counts["C"])])),
"G_counts": Feature(int64_list=Int64List(value=[sum(base_counts["G"])])),
"T_counts": Feature(int64_list=Int64List(value=[sum(base_counts["T"])])),
"sequence": Feature(int64_list=Int64List(value=encode_sequence(sequence))),
"label": Feature(int64_list=Int64List(value=encode_sequence(label))),
}
@ -48,11 +38,9 @@ def read_fastq(data_file, label_file) -> List[bytes]:
examples = []
with open(data_file) as data, open(label_file) as labels:
for element, label in zip(parse(data, "fastq"), parse(labels, "fastq")):
motifs = create([element.seq])
example = generate_example(
sequence=str(element.seq),
label=str(label.seq),
base_counts=motifs.counts,
)
examples.append(example)
return examples
@ -77,19 +65,14 @@ def create_dataset(
test.write(element)
def transform_features(parsed_features) -> List[Tensor]:
def transform_features(parsed_features) -> Dict[str, Tensor]:
"""
Cast and transform the parsed features of an Example into a list of Tensors
Transform the parsed features of an Example into a dict of dense Tensors
"""
features = {}
sparse_features = ["sequence", "label"]
for feature in sparse_features:
parsed_features[feature] = cast(parsed_features[feature], int32)
parsed_features[feature] = to_dense(parsed_features[feature])
for base in BASES:
parsed_features[f"{base}_counts"] = cast(
parsed_features[f"{base}_counts"], int32
)
features = list(parsed_features.values())[:-1]
for element in sparse_features:
features[element] = to_dense(parsed_features[element])
return features
@ -98,16 +81,12 @@ def process_input(byte_string) -> Tuple[Tensor, Tensor]:
Parse a byte-string into an Example object
"""
schema = {
"A_counts": FixedLenFeature(shape=[1], dtype=int64),
"C_counts": FixedLenFeature(shape=[1], dtype=int64),
"G_counts": FixedLenFeature(shape=[1], dtype=int64),
"T_counts": FixedLenFeature(shape=[1], dtype=int64),
"sequence": VarLenFeature(dtype=int64),
"label": VarLenFeature(dtype=int64),
}
parsed_features = parse_single_example(byte_string, features=schema)
features = transform_features(parsed_features)
return stack(features, axis=-1), parsed_features["label"]
return features["sequence"], features["label"]
def read_dataset(filepath) -> TFRecordDataset: