Convert sequence and label to VarLenFeature
This commit is contained in:
parent
c6d0d5959d
commit
19ed847d12
|
@ -3,27 +3,31 @@ from typing import List, Tuple
|
||||||
from Bio.motifs import create
|
from Bio.motifs import create
|
||||||
from Bio.SeqIO import parse
|
from Bio.SeqIO import parse
|
||||||
from numpy.random import random
|
from numpy.random import random
|
||||||
from tensorflow import float32, int64
|
from tensorflow import Tensor, int64, stack
|
||||||
|
from tensorflow.sparse import to_dense
|
||||||
from tensorflow.data import TFRecordDataset
|
from tensorflow.data import TFRecordDataset
|
||||||
from tensorflow.io import FixedLenFeature, TFRecordWriter, parse_single_example
|
from tensorflow.io import (
|
||||||
from tensorflow.train import Example, Feature, Features, FloatList, Int64List
|
FixedLenFeature,
|
||||||
|
TFRecordWriter,
|
||||||
|
VarLenFeature,
|
||||||
|
parse_single_example,
|
||||||
|
)
|
||||||
|
from tensorflow.train import Example, Feature, Features, Int64List
|
||||||
|
|
||||||
from constants import *
|
from constants import *
|
||||||
|
|
||||||
|
|
||||||
def generate_example(sequence, label, weight_matrix) -> bytes:
|
def generate_example(sequence, label, base_counts) -> bytes:
|
||||||
"""
|
"""
|
||||||
Create a binary-string for each sequence containing the sequence and the bases' frequency
|
Create a binary-string for each sequence containing the sequence and the bases' counts
|
||||||
"""
|
"""
|
||||||
schema = {
|
schema = {
|
||||||
"sequence": Feature(
|
"A_counts": Feature(int64_list=Int64List(value=[sum(base_counts["A"])])),
|
||||||
int64_list=Int64List(value=list(encode_sequence(sequence)))
|
"C_counts": Feature(int64_list=Int64List(value=[sum(base_counts["C"])])),
|
||||||
),
|
"G_counts": Feature(int64_list=Int64List(value=[sum(base_counts["G"])])),
|
||||||
"label": Feature(int64_list=Int64List(value=list(encode_sequence(label)))),
|
"T_counts": Feature(int64_list=Int64List(value=[sum(base_counts["T"])])),
|
||||||
"A_counts": Feature(float_list=FloatList(value=weight_matrix["A"])),
|
"sequence": Feature(int64_list=Int64List(value=encode_sequence(sequence))),
|
||||||
"C_counts": Feature(float_list=FloatList(value=weight_matrix["C"])),
|
"label": Feature(int64_list=Int64List(value=encode_sequence(label))),
|
||||||
"G_counts": Feature(float_list=FloatList(value=weight_matrix["G"])),
|
|
||||||
"T_counts": Feature(float_list=FloatList(value=weight_matrix["T"])),
|
|
||||||
}
|
}
|
||||||
example = Example(features=Features(feature=schema))
|
example = Example(features=Features(feature=schema))
|
||||||
return example.SerializeToString()
|
return example.SerializeToString()
|
||||||
|
@ -48,7 +52,7 @@ def read_fastq(data_file, label_file) -> List[bytes]:
|
||||||
example = generate_example(
|
example = generate_example(
|
||||||
sequence=str(element.seq),
|
sequence=str(element.seq),
|
||||||
label=str(label.seq),
|
label=str(label.seq),
|
||||||
weight_matrix=motifs.pwm,
|
base_counts=motifs.counts,
|
||||||
)
|
)
|
||||||
examples.append(example)
|
examples.append(example)
|
||||||
return examples
|
return examples
|
||||||
|
@ -73,19 +77,23 @@ def create_dataset(
|
||||||
test.write(element)
|
test.write(element)
|
||||||
|
|
||||||
|
|
||||||
def process_input(byte_string) -> Example:
|
def process_input(byte_string) -> Tuple[Tensor, Tensor]:
|
||||||
"""
|
"""
|
||||||
Parse a byte-string into an Example object
|
Parse a byte-string into an Example object
|
||||||
"""
|
"""
|
||||||
schema = {
|
schema = {
|
||||||
"sequence": FixedLenFeature(shape=[], dtype=int64),
|
"A_counts": FixedLenFeature(shape=[1], dtype=int64),
|
||||||
"label": FixedLenFeature(shape=[], dtype=int64),
|
"C_counts": FixedLenFeature(shape=[1], dtype=int64),
|
||||||
"A_counts": FixedLenFeature(shape=[], dtype=float32),
|
"G_counts": FixedLenFeature(shape=[1], dtype=int64),
|
||||||
"C_counts": FixedLenFeature(shape=[], dtype=float32),
|
"T_counts": FixedLenFeature(shape=[1], dtype=int64),
|
||||||
"G_counts": FixedLenFeature(shape=[], dtype=float32),
|
"sequence": VarLenFeature(dtype=int64),
|
||||||
"T_counts": FixedLenFeature(shape=[], dtype=float32),
|
"label": VarLenFeature(dtype=int64),
|
||||||
}
|
}
|
||||||
return parse_single_example(byte_string, features=schema)
|
parsed_features = parse_single_example(byte_string, features=schema)
|
||||||
|
parsed_features["sequence"] = to_dense(parsed_features["sequence"])
|
||||||
|
parsed_features["label"] = to_dense(parsed_features["label"])
|
||||||
|
features = list(parsed_features.values())[:-1]
|
||||||
|
return stack(features, axis=-1), parsed_features["label"]
|
||||||
|
|
||||||
|
|
||||||
def read_dataset(filepath) -> TFRecordDataset:
|
def read_dataset(filepath) -> TFRecordDataset:
|
||||||
|
|
Loading…
Reference in New Issue