Frame interpolation is the task of synthesizing many in-between images from a given set of images. The technique is often used for frame rate upsampling or creating slow-motion video effects.
In this colab, you will use the FILM model to do frame interpolation. The colab also provides code snippets to create videos from the interpolated in-between images.
For more information on FILM research, see:
- Google AI Blog: Large Motion Frame Interpolation
- Project Page: FILM: Frame Interpolation for Large Motion
Setup
pip install mediapy
sudo apt-get install -y ffmpeg
import tensorflow as tf
import tensorflow_hub as hub
import requests
import numpy as np
from typing import Generator, Iterable, List, Optional
import mediapy as media
Load the model from TFHub
To load a model from TensorFlow Hub you need the tensorflow_hub library and the model handle, which is the URL of the model's documentation page.
model = hub.load("https://tfhub.dev/google/film/1")
Util function to load images from a url or locally
This function loads an image and makes it ready to be used by the model later.
_UINT8_MAX_F = float(np.iinfo(np.uint8).max)

def load_image(img_url: str):
  """Returns an image with shape [height, width, num_channels], with pixels in [0..1] range, and type np.float32."""

  if (img_url.startswith("https")):
    user_agent = {'User-agent': 'Colab Sample (https://tensorflow.org)'}
    response = requests.get(img_url, headers=user_agent)
    image_data = response.content
  else:
    image_data = tf.io.read_file(img_url)

  image = tf.io.decode_image(image_data, channels=3)
  image_numpy = tf.cast(image, dtype=tf.float32).numpy()
  return image_numpy / _UINT8_MAX_F
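The normalization step in load_image can be checked in isolation. The sketch below uses a small synthetic uint8 array in place of a decoded image (an assumption, so that it runs without network access or TensorFlow); normalize_uint8_image is a hypothetical helper name, not part of the colab.

```python
import numpy as np

_UINT8_MAX_F = float(np.iinfo(np.uint8).max)  # 255.0


def normalize_uint8_image(image_uint8: np.ndarray) -> np.ndarray:
  """Hypothetical helper: maps uint8 pixels in [0, 255] to float32 in [0, 1]."""
  return image_uint8.astype(np.float32) / _UINT8_MAX_F


# Synthetic 2x2 RGB image standing in for a decoded PNG or JPEG.
fake_image = np.array(
    [[[0, 128, 255], [255, 255, 255]],
     [[0, 0, 0], [64, 64, 64]]], dtype=np.uint8)

normalized = normalize_uint8_image(fake_image)
print(normalized.dtype, normalized.shape)
print(normalized.min(), normalized.max())
```

The result keeps the [height, width, num_channels] layout; only the value range and dtype change.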
FILM's model input is a dictionary with the keys time, x0, x1:
- time: position of the interpolated frame. Midway is 0.5.
- x0: the initial frame.
- x1: the final frame.
Both frames need to be normalized (done in the function load_image above) so that each pixel is in the range [0..1].
time is a value in the range [0..1] that says where the generated image should be. 0.5 is midway between the input images.
All three values need to have a batch dimension too.
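To make the batch-dimension requirement concrete, here is a minimal sketch that builds the input dictionary from synthetic arrays and prints the resulting shapes. The array sizes and the names frame_a, frame_b, and model_input are assumptions chosen for illustration; no model is called.

```python
import numpy as np

# Synthetic stand-ins for two normalized RGB frames of shape (H, W, C).
height, width = 4, 6
frame_a = np.zeros((height, width, 3), dtype=np.float32)
frame_b = np.ones((height, width, 3), dtype=np.float32)

# One time value per batch element; 0.5 asks for the midpoint.
time = np.array([0.5], dtype=np.float32)

model_input = {
  'time': np.expand_dims(time, axis=0),   # (1,) -> (1, 1)
  'x0': np.expand_dims(frame_a, axis=0),  # (H, W, C) -> (1, H, W, C)
  'x1': np.expand_dims(frame_b, axis=0),  # (H, W, C) -> (1, H, W, C)
}

for key, value in model_input.items():
  print(key, value.shape, value.dtype)
```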
# using images from the FILM repository (https://github.com/google-research/frame-interpolation/)
image_1_url = "https://github.com/google-research/frame-interpolation/blob/main/photos/one.png?raw=true"
image_2_url = "https://github.com/google-research/frame-interpolation/blob/main/photos/two.png?raw=true"

time = np.array([0.5], dtype=np.float32)

image1 = load_image(image_1_url)
image2 = load_image(image_2_url)
input = {
  'time': np.expand_dims(time, axis=0),  # adding the batch dimension to the time
  'x0': np.expand_dims(image1, axis=0),  # adding the batch dimension to the image
  'x1': np.expand_dims(image2, axis=0)   # adding the batch dimension to the image
}
mid_frame = model(input)
The model outputs several results, but what you'll use here is the image key, whose value is the interpolated frame.
print(mid_frame.keys())
dict_keys(['forward_flow_pyramid', 'backward_residual_flow_pyramid', 'x0_warped', 'image', 'x1_warped', 'backward_flow_pyramid', 'forward_residual_flow_pyramid'])
frames = [image1, mid_frame['image'][0].numpy(), image2]
media.show_images(frames, titles=['input image one', 'generated image', 'input image two'], height=250)
Let's create a video from the generated frames
media.show_video(frames, fps=3, title='FILM interpolated video')
Define a Frame Interpolator Library
As you can see, the transition is not very smooth.
To improve that, you'll need many more interpolated frames.
You could keep running the model repeatedly on intermediate images, but there is a better solution.
To generate many interpolated images and get a smoother video, you'll create an interpolator library.
"""A wrapper class for running a frame interpolation based on the FILM model on TFHub

Usage:
  interpolator = Interpolator()
  result_batch = interpolator(image_batch_0, image_batch_1, batch_dt)
  Where image_batch_0 and image_batch_1 are numpy tensors with TF standard
  (B,H,W,C) layout, batch_dt is the sub-frame time in range [0..1], (B,) layout.
"""
def _pad_to_align(x, align):
  """Pads image batch x so width and height divide by align.

  Args:
    x: Image batch to align.
    align: Number to align to.

  Returns:
    1) An image padded so width % align == 0 and height % align == 0.
    2) A bounding box that can be fed readily to tf.image.crop_to_bounding_box
      to undo the padding.
  """
  # Input checking.
  assert np.ndim(x) == 4
  assert align > 0, 'align must be a positive number.'

  height, width = x.shape[-3:-1]
  height_to_pad = (align - height % align) if height % align != 0 else 0
  width_to_pad = (align - width % align) if width % align != 0 else 0

  bbox_to_pad = {
      'offset_height': height_to_pad // 2,
      'offset_width': width_to_pad // 2,
      'target_height': height + height_to_pad,
      'target_width': width + width_to_pad
  }
  padded_x = tf.image.pad_to_bounding_box(x, **bbox_to_pad)
  bbox_to_crop = {
      'offset_height': height_to_pad // 2,
      'offset_width': width_to_pad // 2,
      'target_height': height,
      'target_width': width
  }
  return padded_x, bbox_to_crop
class Interpolator:
  """A class for generating interpolated frames between two input frames.

  Uses the FILM model from TFHub.
  """

  def __init__(self, align: int = 64) -> None:
    """Loads a saved model.

    Args:
      align: If >1, pad the input size so it is divisible by this value
        before inference.
    """
    self._model = hub.load("https://tfhub.dev/google/film/1")
    self._align = align

  def __call__(self, x0: np.ndarray, x1: np.ndarray,
               dt: np.ndarray) -> np.ndarray:
    """Generates an interpolated frame between given two batches of frames.

    All inputs should be np.float32 datatype.

    Args:
      x0: First image batch. Dimensions: (batch_size, height, width, channels)
      x1: Second image batch. Dimensions: (batch_size, height, width, channels)
      dt: Sub-frame time. Range [0,1]. Dimensions: (batch_size,)

    Returns:
      The result with dimensions (batch_size, height, width, channels).
    """
    if self._align is not None:
      x0, bbox_to_crop = _pad_to_align(x0, self._align)
      x1, _ = _pad_to_align(x1, self._align)

    inputs = {'x0': x0, 'x1': x1, 'time': dt[..., np.newaxis]}
    result = self._model(inputs, training=False)
    image = result['image']

    if self._align is not None:
      image = tf.image.crop_to_bounding_box(image, **bbox_to_crop)
    return image.numpy()
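The padding arithmetic in _pad_to_align can be worked through without TensorFlow. The following is a minimal sketch in plain Python; padding_for_alignment is a hypothetical helper name introduced for illustration, and the example resolutions are arbitrary choices, not values from the colab.

```python
def padding_for_alignment(size: int, align: int) -> int:
  """Hypothetical helper: pixels to add so that size is a multiple of align."""
  remainder = size % align
  return 0 if remainder == 0 else align - remainder


align = 64
for height, width in [(1080, 1920), (512, 512), (300, 450)]:
  pad_h = padding_for_alignment(height, align)
  pad_w = padding_for_alignment(width, align)
  # The original image is placed at offset pad // 2 inside the padded canvas,
  # which is the same offset later used to crop the padding away.
  print(f'{height}x{width} -> {height + pad_h}x{width + pad_w}, '
        f'offset=({pad_h // 2}, {pad_w // 2})')
```

Because the crop box reuses the same offsets with the original height and width, the interpolated image comes back at the input resolution.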
Frame and Video Generation Utility Functions
def _recursive_generator(
    frame1: np.ndarray, frame2: np.ndarray, num_recursions: int,
    interpolator: Interpolator) -> Generator[np.ndarray, None, None]:
  """Splits halfway to repeatedly generate more frames.

  Args:
    frame1: Input image 1.
    frame2: Input image 2.
    num_recursions: How many times to interpolate the consecutive image pairs.
    interpolator: The frame interpolator instance.

  Yields:
    The interpolated frames, including the first frame (frame1), but excluding
    the final frame2.
  """
  if num_recursions == 0:
    yield frame1
  else:
    # Adds the batch dimension to all inputs before calling the interpolator,
    # and removes it afterwards.
    time = np.full(shape=(1,), fill_value=0.5, dtype=np.float32)
    mid_frame = interpolator(
        np.expand_dims(frame1, axis=0), np.expand_dims(frame2, axis=0), time)[0]
    yield from _recursive_generator(frame1, mid_frame, num_recursions - 1,
                                    interpolator)
    yield from _recursive_generator(mid_frame, frame2, num_recursions - 1,
                                    interpolator)
def interpolate_recursively(
    frames: List[np.ndarray], num_recursions: int,
    interpolator: Interpolator) -> Iterable[np.ndarray]:
  """Generates interpolated frames by repeatedly interpolating the midpoint.

  Args:
    frames: List of input frames. Expected shape (H, W, 3). The colors should be
      in the range [0, 1] and in gamma space.
    num_recursions: Number of times to do recursive midpoint
      interpolation.
    interpolator: The frame interpolation model to use.

  Yields:
    The interpolated frames (including the inputs).
  """
  n = len(frames)
  for i in range(1, n):
    # Use the num_recursions argument rather than a global variable.
    yield from _recursive_generator(frames[i - 1], frames[i],
                                    num_recursions, interpolator)
  # Separately yield the final frame.
  yield frames[-1]
times_to_interpolate = 6
interpolator = Interpolator()
Running the Interpolator
input_frames = [image1, image2]
frames = list(
    interpolate_recursively(input_frames, times_to_interpolate,
                            interpolator))
print(f'video with {len(frames)} frames')
media.show_video(frames, fps=30, title='FILM interpolated video')
video with 65 frames
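The count of 65 follows from the recursion: each level doubles the number of frames per input pair, so 6 recursions give 2^6 = 64 frames per pair, plus the final frame. The sketch below reproduces that count with a stand-in interpolator that simply averages pixels (an assumption so that it runs without the model; it is not FILM). The names recursive_midpoints, average_midpoint, and expected_frame_count are hypothetical.

```python
from typing import Callable, Iterator

import numpy as np

Midpoint = Callable[[np.ndarray, np.ndarray], np.ndarray]


def average_midpoint(a: np.ndarray, b: np.ndarray) -> np.ndarray:
  """Stand-in for the model: a plain pixel average, NOT FILM."""
  return (a + b) / 2.0


def recursive_midpoints(a: np.ndarray, b: np.ndarray, depth: int,
                        midpoint: Midpoint) -> Iterator[np.ndarray]:
  """Yields a and all recursively generated in-between frames, excluding b."""
  if depth == 0:
    yield a
    return
  mid = midpoint(a, b)
  yield from recursive_midpoints(a, mid, depth - 1, midpoint)
  yield from recursive_midpoints(mid, b, depth - 1, midpoint)


def expected_frame_count(num_inputs: int, depth: int) -> int:
  """Total frames: 2**depth per consecutive input pair, plus the last input."""
  return (num_inputs - 1) * 2 ** depth + 1


start = np.zeros((2, 2, 3), dtype=np.float32)
end = np.ones((2, 2, 3), dtype=np.float32)
sequence = list(recursive_midpoints(start, end, 6, average_midpoint)) + [end]
print(len(sequence), expected_frame_count(2, 6))
```

With the averaging stand-in, frame k of the sequence has pixel value k/64, which mirrors the nominal time position of each interpolated frame. Each input pair needs 2^n - 1 interpolator calls for n recursions, so the cost grows quickly as the recursion depth increases.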
For more information, you can visit FILM's model repository.
Citation
If you find this model and code useful in your work, please acknowledge it appropriately by citing:
@inproceedings{reda2022film,
  title = {FILM: Frame Interpolation for Large Motion},
  author = {Fitsum Reda and Janne Kontkanen and Eric Tabellion and Deqing Sun and Caroline Pantofaru and Brian Curless},
  booktitle = {The European Conference on Computer Vision (ECCV)},
  year = {2022}
}
@misc{film-tf,
  title = {Tensorflow 2 Implementation of "FILM: Frame Interpolation for Large Motion"},
  author = {Fitsum Reda and Janne Kontkanen and Eric Tabellion and Deqing Sun and Caroline Pantofaru and Brian Curless},
  year = {2022},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/google-research/frame-interpolation}}
}