Add reference sequence to each dataset instance
This commit is contained in:
parent
f30fc31c29
commit
02d20d4e72
|
@ -1,4 +1,4 @@
|
||||||
from typing import List
|
from typing import List, Tuple
|
||||||
|
|
||||||
from Bio.motifs import create
|
from Bio.motifs import create
|
||||||
from Bio.SeqIO import parse
|
from Bio.SeqIO import parse
|
||||||
|
@ -8,10 +8,10 @@ from tensorflow.data import TFRecordDataset
|
||||||
from tensorflow.io import FixedLenFeature, TFRecordWriter, parse_single_example
|
from tensorflow.io import FixedLenFeature, TFRecordWriter, parse_single_example
|
||||||
from tensorflow.train import Example, Feature, Features, FloatList, Int64List
|
from tensorflow.train import Example, Feature, Features, FloatList, Int64List
|
||||||
|
|
||||||
from constants import BASES, BATCH_SIZE, EPOCHS, TEST_DATASET, TRAIN_DATASET
|
from constants import *
|
||||||
|
|
||||||
|
|
||||||
def generate_example(sequence, weight_matrix) -> bytes:
|
def generate_example(sequence, reference_sequence, weight_matrix) -> bytes:
|
||||||
"""
|
"""
|
||||||
Create a binary-string for each sequence containing the sequence and the bases' frequency
|
Create a binary-string for each sequence containing the sequence and the bases' frequency
|
||||||
"""
|
"""
|
||||||
|
@ -19,6 +19,9 @@ def generate_example(sequence, weight_matrix) -> bytes:
|
||||||
"sequence": Feature(
|
"sequence": Feature(
|
||||||
int64_list=Int64List(value=list(encode_sequence(sequence)))
|
int64_list=Int64List(value=list(encode_sequence(sequence)))
|
||||||
),
|
),
|
||||||
|
"reference_sequence": Feature(
|
||||||
|
int64_list=Int64List(value=list(encode_sequence(reference_sequence)))
|
||||||
|
),
|
||||||
"A_counts": Feature(float_list=FloatList(value=[weight_matrix["A"][0]])),
|
"A_counts": Feature(float_list=FloatList(value=[weight_matrix["A"][0]])),
|
||||||
"C_counts": Feature(float_list=FloatList(value=[weight_matrix["C"][0]])),
|
"C_counts": Feature(float_list=FloatList(value=[weight_matrix["C"][0]])),
|
||||||
"G_counts": Feature(float_list=FloatList(value=[weight_matrix["G"][0]])),
|
"G_counts": Feature(float_list=FloatList(value=[weight_matrix["G"][0]])),
|
||||||
|
@ -36,16 +39,19 @@ def encode_sequence(sequence) -> List[int]:
|
||||||
return encoded_sequence
|
return encoded_sequence
|
||||||
|
|
||||||
|
|
||||||
def read_fastq(filepath) -> List[bytes]:
|
def read_fastq(data_file, label_file) -> List[bytes]:
|
||||||
"""
|
"""
|
||||||
Parse a FASTQ file and generate a List of serialized Examples
|
Parses a data and a label FASTQ files and generates a List of serialized Examples
|
||||||
"""
|
"""
|
||||||
examples = []
|
examples = []
|
||||||
with open(filepath) as handle:
|
with open(data_file) as data, open(label_file) as labels:
|
||||||
for row in parse(handle, "fastq"):
|
for element, label in zip(parse(data, "fastq"), parse(labels, "fastq")):
|
||||||
sequence = str(row.seq)
|
motifs = create([element.seq])
|
||||||
motifs = create(row.seq)
|
example = generate_example(
|
||||||
example = generate_example(sequence=sequence, weight_matrix=motifs.pwm)
|
sequence=str(element.seq),
|
||||||
|
reference_sequence=str(label.seq),
|
||||||
|
weight_matrix=motifs.pwm,
|
||||||
|
)
|
||||||
examples.append(example)
|
examples.append(example)
|
||||||
return examples
|
return examples
|
||||||
|
|
||||||
|
@ -54,7 +60,7 @@ def create_dataset(filepath) -> None:
|
||||||
"""
|
"""
|
||||||
Create a training and test dataset with a 70/30 split respectively
|
Create a training and test dataset with a 70/30 split respectively
|
||||||
"""
|
"""
|
||||||
data = read_fastq(filepath)
|
data = read_fastq(data_file, label_file)
|
||||||
train_test_split = 0.7
|
train_test_split = 0.7
|
||||||
with TFRecordWriter(TRAIN_DATASET) as train, TFRecordWriter(TEST_DATASET) as test:
|
with TFRecordWriter(TRAIN_DATASET) as train, TFRecordWriter(TEST_DATASET) as test:
|
||||||
for element in data:
|
for element in data:
|
||||||
|
@ -70,6 +76,7 @@ def process_input(byte_string) -> Example:
|
||||||
"""
|
"""
|
||||||
schema = {
|
schema = {
|
||||||
"sequence": FixedLenFeature(shape=[], dtype=int64),
|
"sequence": FixedLenFeature(shape=[], dtype=int64),
|
||||||
|
"reference_sequence": FixedLenFeature(shape=[], dtype=int64),
|
||||||
"A_counts": FixedLenFeature(shape=[], dtype=float32),
|
"A_counts": FixedLenFeature(shape=[], dtype=float32),
|
||||||
"C_counts": FixedLenFeature(shape=[], dtype=float32),
|
"C_counts": FixedLenFeature(shape=[], dtype=float32),
|
||||||
"G_counts": FixedLenFeature(shape=[], dtype=float32),
|
"G_counts": FixedLenFeature(shape=[], dtype=float32),
|
||||||
|
@ -78,12 +85,19 @@ def process_input(byte_string) -> Example:
|
||||||
return parse_single_example(byte_string, features=schema)
|
return parse_single_example(byte_string, features=schema)
|
||||||
|
|
||||||
|
|
||||||
def read_dataset() -> TFRecordDataset:
|
def read_dataset(filepath) -> TFRecordDataset:
|
||||||
"""
|
"""
|
||||||
Read TFRecords files and generate a dataset
|
Read TFRecords files and generate a dataset
|
||||||
"""
|
"""
|
||||||
data_input = TFRecordDataset(filenames=[TRAIN_DATASET, TEST_DATASET])
|
data_input = TFRecordDataset(filenames=filepath)
|
||||||
dataset = data_input.map(map_func=process_input)
|
dataset = data_input.map(map_func=process_input)
|
||||||
shuffled_dataset = dataset.shuffle(buffer_size=10000, reshuffle_each_iteration=True)
|
shuffled_dataset = dataset.shuffle(buffer_size=10000, seed=42)
|
||||||
batched_dataset = shuffled_dataset.batch(batch_size=BATCH_SIZE).repeat(count=EPOCHS)
|
batched_dataset = shuffled_dataset.batch(batch_size=BATCH_SIZE).repeat(count=EPOCHS)
|
||||||
return batched_dataset
|
return batched_dataset
|
||||||
|
|
||||||
|
|
||||||
|
def dataset_creation(data_file, label_file) -> Tuple[TFRecordDataset, TFRecordDataset]:
|
||||||
|
create_dataset(data_file, label_file)
|
||||||
|
train_data = read_dataset(TRAIN_DATASET)
|
||||||
|
test_data = read_dataset(TEST_DATASET)
|
||||||
|
return train_data, test_data
|
||||||
|
|
Loading…
Reference in New Issue