Remove base counts from the dataset
This commit is contained in:
parent
a2ae7bbe11
commit
1e433c123f
|
@ -86,7 +86,6 @@ def run(data_file, label_file, seed_value=42) -> None:
|
|||
epochs=EPOCHS,
|
||||
validation_data=eval_data,
|
||||
callbacks=[tensorboard],
|
||||
verbose=0,
|
||||
)
|
||||
print("Training complete. Obtaining final metrics...")
|
||||
show_metrics(model, eval_data, test_data)
|
||||
|
|
|
@ -1,31 +1,21 @@
|
|||
from typing import List, Tuple
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
from Bio.motifs import create
|
||||
from Bio.SeqIO import parse
|
||||
from numpy.random import random
|
||||
from tensorflow import Tensor, int64, stack, cast, int32
|
||||
from tensorflow.sparse import to_dense
|
||||
from tensorflow import Tensor, int64
|
||||
from tensorflow.data import TFRecordDataset
|
||||
from tensorflow.io import (
|
||||
FixedLenFeature,
|
||||
TFRecordWriter,
|
||||
VarLenFeature,
|
||||
parse_single_example,
|
||||
)
|
||||
from tensorflow.io import TFRecordWriter, VarLenFeature, parse_single_example
|
||||
from tensorflow.sparse import to_dense
|
||||
from tensorflow.train import Example, Feature, Features, Int64List
|
||||
|
||||
from constants import *
|
||||
|
||||
|
||||
def generate_example(sequence, label, base_counts) -> bytes:
|
||||
def generate_example(sequence, label) -> bytes:
|
||||
"""
|
||||
Create a binary-string for each sequence containing the sequence and the bases' counts
|
||||
"""
|
||||
schema = {
|
||||
"A_counts": Feature(int64_list=Int64List(value=[sum(base_counts["A"])])),
|
||||
"C_counts": Feature(int64_list=Int64List(value=[sum(base_counts["C"])])),
|
||||
"G_counts": Feature(int64_list=Int64List(value=[sum(base_counts["G"])])),
|
||||
"T_counts": Feature(int64_list=Int64List(value=[sum(base_counts["T"])])),
|
||||
"sequence": Feature(int64_list=Int64List(value=encode_sequence(sequence))),
|
||||
"label": Feature(int64_list=Int64List(value=encode_sequence(label))),
|
||||
}
|
||||
|
@ -48,11 +38,9 @@ def read_fastq(data_file, label_file) -> List[bytes]:
|
|||
examples = []
|
||||
with open(data_file) as data, open(label_file) as labels:
|
||||
for element, label in zip(parse(data, "fastq"), parse(labels, "fastq")):
|
||||
motifs = create([element.seq])
|
||||
example = generate_example(
|
||||
sequence=str(element.seq),
|
||||
label=str(label.seq),
|
||||
base_counts=motifs.counts,
|
||||
)
|
||||
examples.append(example)
|
||||
return examples
|
||||
|
@ -77,19 +65,14 @@ def create_dataset(
|
|||
test.write(element)
|
||||
|
||||
|
||||
def transform_features(parsed_features) -> List[Tensor]:
|
||||
def transform_features(parsed_features) -> Dict[str, Tensor]:
|
||||
"""
|
||||
Cast and transform the parsed features of an Example into a list of Tensors
|
||||
Transform the parsed features of an Example into a list of dense Tensors
|
||||
"""
|
||||
features = {}
|
||||
sparse_features = ["sequence", "label"]
|
||||
for feature in sparse_features:
|
||||
parsed_features[feature] = cast(parsed_features[feature], int32)
|
||||
parsed_features[feature] = to_dense(parsed_features[feature])
|
||||
for base in BASES:
|
||||
parsed_features[f"{base}_counts"] = cast(
|
||||
parsed_features[f"{base}_counts"], int32
|
||||
)
|
||||
features = list(parsed_features.values())[:-1]
|
||||
for element in sparse_features:
|
||||
features[element] = to_dense(parsed_features[element])
|
||||
return features
|
||||
|
||||
|
||||
|
@ -98,16 +81,12 @@ def process_input(byte_string) -> Tuple[Tensor, Tensor]:
|
|||
Parse a byte-string into an Example object
|
||||
"""
|
||||
schema = {
|
||||
"A_counts": FixedLenFeature(shape=[1], dtype=int64),
|
||||
"C_counts": FixedLenFeature(shape=[1], dtype=int64),
|
||||
"G_counts": FixedLenFeature(shape=[1], dtype=int64),
|
||||
"T_counts": FixedLenFeature(shape=[1], dtype=int64),
|
||||
"sequence": VarLenFeature(dtype=int64),
|
||||
"label": VarLenFeature(dtype=int64),
|
||||
}
|
||||
parsed_features = parse_single_example(byte_string, features=schema)
|
||||
features = transform_features(parsed_features)
|
||||
return stack(features, axis=-1), parsed_features["label"]
|
||||
return features["sequence"], features["label"]
|
||||
|
||||
|
||||
def read_dataset(filepath) -> TFRecordDataset:
|
||||
|
|
Loading…
Reference in New Issue