Convert sequence and label to VarLenFeature

This commit is contained in:
coolneng 2021-06-14 19:33:42 +02:00
parent c6d0d5959d
commit 19ed847d12
Signed by: coolneng
GPG Key ID: 9893DA236405AF57
1 changed files with 30 additions and 22 deletions

View File

@ -3,27 +3,31 @@ from typing import List, Tuple
from Bio.motifs import create
from Bio.SeqIO import parse
from numpy.random import random
from tensorflow import float32, int64
from tensorflow import Tensor, int64, stack
from tensorflow.sparse import to_dense
from tensorflow.data import TFRecordDataset
from tensorflow.io import FixedLenFeature, TFRecordWriter, parse_single_example
from tensorflow.train import Example, Feature, Features, FloatList, Int64List
from tensorflow.io import (
FixedLenFeature,
TFRecordWriter,
VarLenFeature,
parse_single_example,
)
from tensorflow.train import Example, Feature, Features, Int64List
from constants import *
def generate_example(sequence, label, weight_matrix) -> bytes:
def generate_example(sequence, label, base_counts) -> bytes:
"""
Create a binary-string for each sequence containing the sequence and the bases' frequency
Create a binary-string for each sequence containing the sequence and the bases' counts
"""
schema = {
"sequence": Feature(
int64_list=Int64List(value=list(encode_sequence(sequence)))
),
"label": Feature(int64_list=Int64List(value=list(encode_sequence(label)))),
"A_counts": Feature(float_list=FloatList(value=weight_matrix["A"])),
"C_counts": Feature(float_list=FloatList(value=weight_matrix["C"])),
"G_counts": Feature(float_list=FloatList(value=weight_matrix["G"])),
"T_counts": Feature(float_list=FloatList(value=weight_matrix["T"])),
"A_counts": Feature(int64_list=Int64List(value=[sum(base_counts["A"])])),
"C_counts": Feature(int64_list=Int64List(value=[sum(base_counts["C"])])),
"G_counts": Feature(int64_list=Int64List(value=[sum(base_counts["G"])])),
"T_counts": Feature(int64_list=Int64List(value=[sum(base_counts["T"])])),
"sequence": Feature(int64_list=Int64List(value=encode_sequence(sequence))),
"label": Feature(int64_list=Int64List(value=encode_sequence(label))),
}
example = Example(features=Features(feature=schema))
return example.SerializeToString()
@ -48,7 +52,7 @@ def read_fastq(data_file, label_file) -> List[bytes]:
example = generate_example(
sequence=str(element.seq),
label=str(label.seq),
weight_matrix=motifs.pwm,
base_counts=motifs.counts,
)
examples.append(example)
return examples
@ -73,19 +77,23 @@ def create_dataset(
test.write(element)
def process_input(byte_string) -> Example:
def process_input(byte_string) -> Tuple[Tensor, Tensor]:
"""
Parse a byte-string into an Example object
"""
schema = {
"sequence": FixedLenFeature(shape=[], dtype=int64),
"label": FixedLenFeature(shape=[], dtype=int64),
"A_counts": FixedLenFeature(shape=[], dtype=float32),
"C_counts": FixedLenFeature(shape=[], dtype=float32),
"G_counts": FixedLenFeature(shape=[], dtype=float32),
"T_counts": FixedLenFeature(shape=[], dtype=float32),
"A_counts": FixedLenFeature(shape=[1], dtype=int64),
"C_counts": FixedLenFeature(shape=[1], dtype=int64),
"G_counts": FixedLenFeature(shape=[1], dtype=int64),
"T_counts": FixedLenFeature(shape=[1], dtype=int64),
"sequence": VarLenFeature(dtype=int64),
"label": VarLenFeature(dtype=int64),
}
return parse_single_example(byte_string, features=schema)
parsed_features = parse_single_example(byte_string, features=schema)
parsed_features["sequence"] = to_dense(parsed_features["sequence"])
parsed_features["label"] = to_dense(parsed_features["label"])
features = list(parsed_features.values())[:-1]
return stack(features, axis=-1), parsed_features["label"]
def read_dataset(filepath) -> TFRecordDataset: