format all code

This commit is contained in:
Tanishq Dubey 2023-02-19 10:47:13 -05:00
parent 8b97fa5ebd
commit 937c938837
11 changed files with 146 additions and 62 deletions

115
main.py
View File

@ -24,14 +24,10 @@ from src.math.distribution import create_distribution
log = structlog.get_logger() log = structlog.get_logger()
EDITORS = { EDITORS = {"amplitude": AmplitudeEditor, "sentiment": SentimentEditor}
'amplitude': AmplitudeEditor,
'sentiment': SentimentEditor ERROR_FUNCS = {"quadratic": quadratic_loss}
}
ERROR_FUNCS = {
'quadratic': quadratic_loss
}
def main(args): def main(args):
# Check video existance # Check video existance
@ -46,7 +42,9 @@ def main(args):
if intro_file is not None: if intro_file is not None:
intro_vid_path = Path(intro_file) intro_vid_path = Path(intro_file)
if not in_vid_path.is_file(): if not in_vid_path.is_file():
log.error("the specified input path does not exist for the intro", path=intro_file) log.error(
"the specified input path does not exist for the intro", path=intro_file
)
sys.exit(-1) sys.exit(-1)
log.info("found intro", input_video=intro_file) log.info("found intro", input_video=intro_file)
@ -54,7 +52,9 @@ def main(args):
if outro_file is not None: if outro_file is not None:
outro_vid_path = Path(outro_file) outro_vid_path = Path(outro_file)
if not outro_vid_path.is_file(): if not outro_vid_path.is_file():
log.error("the specified input path does not exist for the outro", path=outro_file) log.error(
"the specified input path does not exist for the outro", path=outro_file
)
sys.exit(-1) sys.exit(-1)
log.info("found outro", input_video=outro_file) log.info("found outro", input_video=outro_file)
@ -62,7 +62,7 @@ def main(args):
# and as a simple way to generate temp file names # and as a simple way to generate temp file names
sha1 = hashlib.sha1() sha1 = hashlib.sha1()
BUF_SIZE = 1655360 BUF_SIZE = 1655360
with open(in_vid_path, 'rb') as f: with open(in_vid_path, "rb") as f:
while True: while True:
data = f.read(BUF_SIZE) data = f.read(BUF_SIZE)
if not data: if not data:
@ -73,14 +73,16 @@ def main(args):
temp_file_name = f"ale-{temp_file_name}" temp_file_name = f"ale-{temp_file_name}"
# Prepare the input video # Prepare the input video
audio_path, audio_cached = extract_audio_from_video(str(in_vid_path.resolve()), temp_file_name) audio_path, audio_cached = extract_audio_from_video(
str(in_vid_path.resolve()), temp_file_name
)
if audio_cached: if audio_cached:
log.info("using cached audio file", cache_path=audio_path) log.info("using cached audio file", cache_path=audio_path)
else: else:
log.info("extracted audio", cache_path=audio_path) log.info("extracted audio", cache_path=audio_path)
params = vars(args) params = vars(args)
params["temp_file_name"] = temp_file_name params["temp_file_name"] = temp_file_name
# Initalize Editor # Initalize Editor
log.info("initializing editor", editor=args.editor) log.info("initializing editor", editor=args.editor)
editor = EDITORS[args.editor](str(in_vid_path.resolve()), audio_path, params) editor = EDITORS[args.editor](str(in_vid_path.resolve()), audio_path, params)
@ -111,9 +113,13 @@ def main(args):
complete = False complete = False
iterations = 0 iterations = 0
while not complete: while not complete:
large_distribution = create_distribution(large_window_center, spread_multiplier, parallelism) large_distribution = create_distribution(
large_window_center, spread_multiplier, parallelism
)
np.random.shuffle(large_distribution) np.random.shuffle(large_distribution)
small_distribution = create_distribution(small_window_center, spread_multiplier, parallelism) small_distribution = create_distribution(
small_window_center, spread_multiplier, parallelism
)
np.random.shuffle(small_distribution) np.random.shuffle(small_distribution)
# Fire off workers to generate edits # Fire off workers to generate edits
@ -127,7 +133,7 @@ def main(args):
editor.edit, editor.edit,
pair[0] if pair[0] > pair[1] else pair[1], pair[0] if pair[0] > pair[1] else pair[1],
pair[1] if pair[0] > pair[1] else pair[0], pair[1] if pair[0] > pair[1] else pair[0],
vars(args) vars(args),
) )
) )
for future in concurrent.futures.as_completed(futures): for future in concurrent.futures.as_completed(futures):
@ -143,24 +149,49 @@ def main(args):
total_duration = 0 total_duration = 0
result[0] = filter_moments(result[0], args.mindur, args.maxdur) result[0] = filter_moments(result[0], args.mindur, args.maxdur)
for moment in result[0]: for moment in result[0]:
total_duration = total_duration + moment.get_duration() total_duration = total_duration + moment.get_duration()
costs.append(costfunc(desired, total_duration)) costs.append(costfunc(desired, total_duration))
durations.append(total_duration) durations.append(total_duration)
index_min = min(range(len(costs)), key=costs.__getitem__) index_min = min(range(len(costs)), key=costs.__getitem__)
large_window_center = moment_results[index_min][1] large_window_center = moment_results[index_min][1]
small_window_center = moment_results[index_min][2] small_window_center = moment_results[index_min][2]
log.info("batch complete", best_large=large_window_center, best_small=small_window_center, duration=durations[index_min]) log.info(
if durations[index_min] > desired * 0.95 and desired * 1.05 > durations[index_min]: "batch complete",
log.info("found edit within target duration", target=desired, duration=durations[index_min]) best_large=large_window_center,
best_small=small_window_center,
duration=durations[index_min],
)
if (
durations[index_min] > desired * 0.95
and desired * 1.05 > durations[index_min]
):
log.info(
"found edit within target duration",
target=desired,
duration=durations[index_min],
)
out_path = Path(args.destination) out_path = Path(args.destination)
log.info("rendering...") log.info("rendering...")
start = time.time() start = time.time()
render_moments(moment_results[index_min][0], str(in_vid_path.resolve()), str(out_path.resolve()), intro_path=intro_file, parallelism=args.parallelism) render_moments(
log.info("render complete", duration=time.time() - start, output=str(out_path.resolve())) moment_results[index_min][0],
str(in_vid_path.resolve()),
str(out_path.resolve()),
intro_path=intro_file,
parallelism=args.parallelism,
)
log.info(
"render complete",
duration=time.time() - start,
output=str(out_path.resolve()),
)
sys.exit(0) sys.exit(0)
iterations = iterations + parallelism iterations = iterations + parallelism
if iterations > 50000: if iterations > 50000:
log.error("could not find a viable edit in the target duration, try other params", target=desired) log.error(
"could not find a viable edit in the target duration, try other params",
target=desired,
)
sys.exit(-4) sys.exit(-4)
spread_multiplier = spread_multiplier - spread_decay spread_multiplier = spread_multiplier - spread_decay
if spread_multiplier < 0: if spread_multiplier < 0:
@ -173,15 +204,23 @@ def main(args):
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
prog="ALE", description="ALE: Automatic Linear Editor.", prog="ALE",
formatter_class=partial(argparse.HelpFormatter, width=100) description="ALE: Automatic Linear Editor.",
formatter_class=partial(argparse.HelpFormatter, width=100),
)
parser.add_argument("file", help="Path to the video file to edit")
parser.add_argument(
"duration", help="Target length of the edit, in seconds", type=int
)
parser.add_argument("destination", help="Edited video save location")
subparsers = parser.add_subparsers(
dest="editor", help="The editing algorithm to use"
) )
parser.add_argument('file', help='Path to the video file to edit')
parser.add_argument('duration', help='Target length of the edit, in seconds', type=int)
parser.add_argument('destination', help='Edited video save location')
subparsers = parser.add_subparsers(dest='editor', help='The editing algorithm to use')
parser_audio_amp = subparsers.add_parser('amplitude', help='The amplitude editor uses audio amplitude moving averages to find swings from relatively quiet moments to loud moments. This is useful in videos where long moments of quiet are interspersed with loud action filled moments.') parser_audio_amp = subparsers.add_parser(
"amplitude",
help="The amplitude editor uses audio amplitude moving averages to find swings from relatively quiet moments to loud moments. This is useful in videos where long moments of quiet are interspersed with loud action filled moments.",
)
parser_audio_amp.add_argument( parser_audio_amp.add_argument(
"--factor", "--factor",
default=16000, default=16000,
@ -190,7 +229,10 @@ if __name__ == "__main__":
type=int, type=int,
) )
parser_sentiment = subparsers.add_parser('sentiment', help='The sentiment editor transcribes the speech in a video and runs sentiment analysis on the resulting text. Using moving averages, large swings in sentiment can be correlated to controversial or exciting moments. A GPU with CUDA is recommended for fast results.') parser_sentiment = subparsers.add_parser(
"sentiment",
help="The sentiment editor transcribes the speech in a video and runs sentiment analysis on the resulting text. Using moving averages, large swings in sentiment can be correlated to controversial or exciting moments. A GPU with CUDA is recommended for fast results.",
)
parser_sentiment.add_argument( parser_sentiment.add_argument(
"--model", "--model",
default="base", default="base",
@ -199,8 +241,17 @@ if __name__ == "__main__":
choices=["base", "tiny", "small", "medium", "large"], choices=["base", "tiny", "small", "medium", "large"],
) )
parser.add_argument("-p", "--parallelism", dest="parallelism", type=int, default=multiprocessing.cpu_count() - 2, help="The number of cores to use, defaults to N - 2 cores.") parser.add_argument(
parser.add_argument("--cost-function", dest="cost", choices=ERROR_FUNCS.keys(), default='quadratic') "-p",
"--parallelism",
dest="parallelism",
type=int,
default=multiprocessing.cpu_count() - 2,
help="The number of cores to use, defaults to N - 2 cores.",
)
parser.add_argument(
"--cost-function", dest="cost", choices=ERROR_FUNCS.keys(), default="quadratic"
)
parser.add_argument( parser.add_argument(
"--min-duration", "--min-duration",

View File

@ -4,6 +4,7 @@ from ..common import find_moving_average_highlights
import numpy as np import numpy as np
import structlog import structlog
class AmplitudeEditor: class AmplitudeEditor:
def __init__(self, video_path, audio_path, params): def __init__(self, video_path, audio_path, params):
self.logger = structlog.get_logger("amplitude") self.logger = structlog.get_logger("amplitude")
@ -21,8 +22,13 @@ class AmplitudeEditor:
def edit(self, large_window, small_window, params): def edit(self, large_window, small_window, params):
window_factor = self.bitrate / self.factor window_factor = self.bitrate / self.factor
long_ma = np_moving_average(self.squared_subsample, large_window * window_factor) long_ma = np_moving_average(
short_ma = np_moving_average(self.squared_subsample, small_window * window_factor) self.squared_subsample, large_window * window_factor
highlights = find_moving_average_highlights(short_ma, long_ma, self.factor / self.bitrate) )
short_ma = np_moving_average(
self.squared_subsample, small_window * window_factor
)
highlights = find_moving_average_highlights(
short_ma, long_ma, self.factor / self.bitrate
)
return highlights, large_window, small_window return highlights, large_window, small_window

View File

@ -1,5 +1,6 @@
from ..models.moment import Moment from ..models.moment import Moment
def find_moving_average_highlights(short_ma, long_ma, scaling_factor=1): def find_moving_average_highlights(short_ma, long_ma, scaling_factor=1):
in_a_clip = False in_a_clip = False
m = None m = None

View File

@ -12,12 +12,14 @@ from flair.data import Sentence
from ...math.average import np_moving_average from ...math.average import np_moving_average
from ..common import find_moving_average_highlights from ..common import find_moving_average_highlights
@dataclass @dataclass
class TextGlob: class TextGlob:
start:float start: float
stop:float stop: float
text:str text: str
sentiment:float sentiment: float
class SentimentEditor: class SentimentEditor:
def __init__(self, video_path, audio_path, params): def __init__(self, video_path, audio_path, params):
@ -30,20 +32,24 @@ class SentimentEditor:
self.logger.info("transcribing audio", path=audio_path) self.logger.info("transcribing audio", path=audio_path)
self.result = self.model.transcribe(audio_path) self.result = self.model.transcribe(audio_path)
with open(dest_location, 'w') as fp: with open(dest_location, "w") as fp:
json.dump(self.result, fp) json.dump(self.result, fp)
else: else:
self.logger.info("cached transcription found", path=dest_location) self.logger.info("cached transcription found", path=dest_location)
with open(dest_location, 'r') as f: with open(dest_location, "r") as f:
self.result = json.load(f) self.result = json.load(f)
self.segments = [] self.segments = []
for segment in self.result['segments']: for segment in self.result["segments"]:
self.segments.append(TextGlob(segment['start'], segment['end'], segment['text'], 0)) self.segments.append(
classifier = TextClassifier.load('en-sentiment') TextGlob(segment["start"], segment["end"], segment["text"], 0)
)
classifier = TextClassifier.load("en-sentiment")
self.sentiments = [] self.sentiments = []
self.logger.info("calculating sentiment on segments", segments=len(self.segments)) self.logger.info(
"calculating sentiment on segments", segments=len(self.segments)
)
for segment in self.segments: for segment in self.segments:
sentence = Sentence(segment.text) sentence = Sentence(segment.text)
classifier.predict(sentence) classifier.predict(sentence)
@ -53,11 +59,13 @@ class SentimentEditor:
segment.sentiment = sentsum segment.sentiment = sentsum
self.sentiments.append(sentsum) self.sentiments.append(sentsum)
self.sentiments = np.array(self.sentiments) self.sentiments = np.array(self.sentiments)
def edit(self, large_window, small_window, params): def edit(self, large_window, small_window, params):
end_time = self.segments[-1].stop end_time = self.segments[-1].stop
window_factor = len(self.sentiments) / end_time window_factor = len(self.sentiments) / end_time
long_ma = np_moving_average(self.sentiments, large_window) long_ma = np_moving_average(self.sentiments, large_window)
short_ma = np_moving_average(self.sentiments, small_window) short_ma = np_moving_average(self.sentiments, small_window)
highlights = find_moving_average_highlights(short_ma, long_ma, 1.0 / window_factor) highlights = find_moving_average_highlights(
short_ma, long_ma, 1.0 / window_factor
)
return highlights, large_window, small_window return highlights, large_window, small_window

View File

@ -1,4 +1,5 @@
import numpy as np import numpy as np
def np_moving_average(data: int, window: int) -> np.ndarray: def np_moving_average(data: int, window: int) -> np.ndarray:
return np.convolve(data, np.ones(int(window)), "valid") / window return np.convolve(data, np.ones(int(window)), "valid") / window

View File

@ -3,5 +3,6 @@ Functions in this file should always target for 0 to be the
lowest possible error -> smaller values win lowest possible error -> smaller values win
""" """
def quadratic_loss(target, result): def quadratic_loss(target, result):
return (target - result)**2.0 return (target - result) ** 2.0

View File

@ -1,5 +1,6 @@
import numpy as np import numpy as np
def create_distribution(center, spread, count): def create_distribution(center, spread, count):
high = center * (1.0 + spread) high = center * (1.0 + spread)
low = center - (center * spread) low = center - (center * spread)

View File

@ -4,10 +4,8 @@ from pathlib import Path
import numpy as np import numpy as np
import scipy.io.wavfile as wav import scipy.io.wavfile as wav
def extract_audio_from_video(
video_path: str, def extract_audio_from_video(video_path: str, filename: str):
filename: str
):
tempdir = tempfile.gettempdir() tempdir = tempfile.gettempdir()
dest_location = f"{tempdir}/{filename}.wav" dest_location = f"{tempdir}/{filename}.wav"
@ -19,6 +17,7 @@ def extract_audio_from_video(
vid.close() vid.close()
return dest_location, False return dest_location, False
def process_audio(source_audio_path): def process_audio(source_audio_path):
rate, data_raw = wav.read(source_audio_path) rate, data_raw = wav.read(source_audio_path)
data_raw = data_raw.astype(np.int32) data_raw = data_raw.astype(np.int32)
@ -26,5 +25,6 @@ def process_audio(source_audio_path):
duration = len(mono) / rate duration = len(mono) / rate
return mono, duration, rate return mono, duration, rate
def resample(data: np.ndarray, factor: int) -> np.ndarray: def resample(data: np.ndarray, factor: int) -> np.ndarray:
return data[::factor].copy() return data[::factor].copy()

View File

@ -9,16 +9,27 @@ def get_subclips(source_video_path, moments):
return clips, vid return clips, vid
def render_moments(moments, input_video_path, output_path, intro_path=None, outro_path=None, parallelism=1): def render_moments(
moments,
input_video_path,
output_path,
intro_path=None,
outro_path=None,
parallelism=1,
):
clips, _ = get_subclips(input_video_path, moments) clips, _ = get_subclips(input_video_path, moments)
if intro_path is not None: if intro_path is not None:
size = clips[0].size size = clips[0].size
iclip = mp.VideoFileClip(intro_path) iclip = mp.VideoFileClip(intro_path)
iclip.resize(height=size[1]) iclip.resize(height=size[1])
clips.insert(0, iclip) clips.insert(0, iclip)
composite = mp.concatenate_videoclips(clips, method='compose') composite = mp.concatenate_videoclips(clips, method="compose")
composite.write_videofile(output_path, logger=None, threads=parallelism) composite.write_videofile(output_path, logger=None, threads=parallelism)
def filter_moments(moments, min_length, max_length): def filter_moments(moments, min_length, max_length):
return [m for m in moments if m.get_duration() > min_length and m.get_duration() < max_length] return [
m
for m in moments
if m.get_duration() > min_length and m.get_duration() < max_length
]

View File

@ -1,6 +1,4 @@
class Moment: class Moment:
def __init__(self, start, stop): def __init__(self, start, stop):
self.start = start self.start = start
self.stop = stop self.stop = stop

View File

@ -9,12 +9,18 @@ log = structlog.get_logger()
def install_ffmpeg(): def install_ffmpeg():
if not click.confirm('Do you want to install ffmpeg? It is required for ALE.', default=False): if not click.confirm(
log.warn("ffmpeg not installed. Please install it manually or restart ALE. Exiting...") "Do you want to install ffmpeg? It is required for ALE.", default=False
):
log.warn(
"ffmpeg not installed. Please install it manually or restart ALE. Exiting..."
)
sys.exit(0) sys.exit(0)
system = platform.system().lower() system = platform.system().lower()
if system == "linux": if system == "linux":
package_manager = "apt-get" if os.path.exists("/etc/apt/sources.list") else "yum" package_manager = (
"apt-get" if os.path.exists("/etc/apt/sources.list") else "yum"
)
command = f"sudo {package_manager} install -y ffmpeg" command = f"sudo {package_manager} install -y ffmpeg"
elif system == "darwin": elif system == "darwin":
command = "brew install ffmpeg" command = "brew install ffmpeg"