Compare commits
2 Commits
035162bd8d
...
2920db70b4
Author | SHA1 | Date |
---|---|---|
coolneng | 2920db70b4 | |
coolneng | 38903c5737 |
|
@ -11,7 +11,7 @@ from tensorflow.train import Example, Feature, Features, FloatList, Int64List
|
||||||
from constants import *
|
from constants import *
|
||||||
|
|
||||||
|
|
||||||
def generate_example(sequence, reference_sequence, weight_matrix) -> bytes:
|
def generate_example(sequence, label, weight_matrix) -> bytes:
|
||||||
"""
|
"""
|
||||||
Create a binary-string for each sequence containing the sequence and the bases' frequency
|
Create a binary-string for each sequence containing the sequence and the bases' frequency
|
||||||
"""
|
"""
|
||||||
|
@ -19,9 +19,7 @@ def generate_example(sequence, reference_sequence, weight_matrix) -> bytes:
|
||||||
"sequence": Feature(
|
"sequence": Feature(
|
||||||
int64_list=Int64List(value=list(encode_sequence(sequence)))
|
int64_list=Int64List(value=list(encode_sequence(sequence)))
|
||||||
),
|
),
|
||||||
"reference_sequence": Feature(
|
"label": Feature(int64_list=Int64List(value=list(encode_sequence(label)))),
|
||||||
int64_list=Int64List(value=list(encode_sequence(reference_sequence)))
|
|
||||||
),
|
|
||||||
"A_counts": Feature(float_list=FloatList(value=weight_matrix["A"])),
|
"A_counts": Feature(float_list=FloatList(value=weight_matrix["A"])),
|
||||||
"C_counts": Feature(float_list=FloatList(value=weight_matrix["C"])),
|
"C_counts": Feature(float_list=FloatList(value=weight_matrix["C"])),
|
||||||
"G_counts": Feature(float_list=FloatList(value=weight_matrix["G"])),
|
"G_counts": Feature(float_list=FloatList(value=weight_matrix["G"])),
|
||||||
|
@ -49,23 +47,27 @@ def read_fastq(data_file, label_file) -> List[bytes]:
|
||||||
motifs = create([element.seq])
|
motifs = create([element.seq])
|
||||||
example = generate_example(
|
example = generate_example(
|
||||||
sequence=str(element.seq),
|
sequence=str(element.seq),
|
||||||
reference_sequence=str(label.seq),
|
label=str(label.seq),
|
||||||
weight_matrix=motifs.pwm,
|
weight_matrix=motifs.pwm,
|
||||||
)
|
)
|
||||||
examples.append(example)
|
examples.append(example)
|
||||||
return examples
|
return examples
|
||||||
|
|
||||||
|
|
||||||
def create_dataset(filepath) -> None:
|
def create_dataset(data_file, label_file) -> None:
|
||||||
"""
|
"""
|
||||||
Create a training and test dataset with a 70/30 split respectively
|
Create a training and test dataset with a 70/30 split respectively
|
||||||
"""
|
"""
|
||||||
data = read_fastq(data_file, label_file)
|
data = read_fastq(data_file, label_file)
|
||||||
train_test_split = 0.7
|
train_eval_test_split = [0.8, 0.1, 0.1]
|
||||||
with TFRecordWriter(TRAIN_DATASET) as train, TFRecordWriter(TEST_DATASET) as test:
|
with TFRecordWriter(TRAIN_DATASET) as training, TFRecordWriter(
|
||||||
|
TEST_DATASET
|
||||||
|
) as test, TFRecordWriter(EVAL_DATASET) as evaluation:
|
||||||
for element in data:
|
for element in data:
|
||||||
if random() < train_test_split:
|
if random() < train_eval_test_split[0]:
|
||||||
train.write(element)
|
training.write(element)
|
||||||
|
elif random() < train_eval_test_split[0] + train_eval_test_split[1]:
|
||||||
|
evaluation.write(element)
|
||||||
else:
|
else:
|
||||||
test.write(element)
|
test.write(element)
|
||||||
|
|
||||||
|
@ -76,7 +78,7 @@ def process_input(byte_string) -> Example:
|
||||||
"""
|
"""
|
||||||
schema = {
|
schema = {
|
||||||
"sequence": FixedLenFeature(shape=[], dtype=int64),
|
"sequence": FixedLenFeature(shape=[], dtype=int64),
|
||||||
"reference_sequence": FixedLenFeature(shape=[], dtype=int64),
|
"label": FixedLenFeature(shape=[], dtype=int64),
|
||||||
"A_counts": FixedLenFeature(shape=[], dtype=float32),
|
"A_counts": FixedLenFeature(shape=[], dtype=float32),
|
||||||
"C_counts": FixedLenFeature(shape=[], dtype=float32),
|
"C_counts": FixedLenFeature(shape=[], dtype=float32),
|
||||||
"G_counts": FixedLenFeature(shape=[], dtype=float32),
|
"G_counts": FixedLenFeature(shape=[], dtype=float32),
|
||||||
|
@ -96,8 +98,11 @@ def read_dataset(filepath) -> TFRecordDataset:
|
||||||
return batched_dataset
|
return batched_dataset
|
||||||
|
|
||||||
|
|
||||||
def dataset_creation(data_file, label_file) -> Tuple[TFRecordDataset, TFRecordDataset]:
|
def dataset_creation(
|
||||||
|
data_file, label_file
|
||||||
|
) -> Tuple[TFRecordDataset, TFRecordDataset, TFRecordDataset]:
|
||||||
create_dataset(data_file, label_file)
|
create_dataset(data_file, label_file)
|
||||||
train_data = read_dataset(TRAIN_DATASET)
|
train_data = read_dataset(TRAIN_DATASET)
|
||||||
|
eval_data = read_dataset(EVAL_DATASET)
|
||||||
test_data = read_dataset(TEST_DATASET)
|
test_data = read_dataset(TEST_DATASET)
|
||||||
return train_data, test_data
|
return train_data, eval_data, test_data
|
||||||
|
|
Loading…
Reference in New Issue