ReEzSynthを動かす

ReEzSynth:PyTorch搭載のEbsynthリメイク

GitHub – FuouM/ReEzSynth: EbSynth in Python, バージョン2 · GitHub

git clone https://github.com/FuouM/ReEzSynth.git
cd ReEzSynth
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121

pip install .

# Build as distributable wheel
python setup.py bdist_wheel

models

models
├───neuflow
│       neuflow_mixed.pth
│       neuflow_sintel.pth
│       neuflow_things.pth
│
└───raft
        raft-kitti.pth
        raft-sintel.pth
        raft-small.pth

Command-Line (YAML Projects)

python prepare_video.py --video "path/to/your/video.mp4" --output "projects/my_project/content"
python run.py --config "configs/example_project.yml"

YAML

# Ezsynth v2 Project Configuration for "My Awesome Project"

project:
  name: "example_project"
  # --- REQUIRED PATHS ---
  content_dir: "projects/example_project/content"
  style_path:
    - "projects/example_project/style000.jpg"
  style_indices: [0]
  output_dir: "output/example_project"
  # --- OPTIONAL PATHS ---
  mask_dir: null # Set to null as we are not using a mask for this test
  # --- CACHING ---
  cache_dir: "cache/example_project"
  force_recompute_flow: false # If true, ignores existing flow cache and re-computes.
  force_recompute_edge: false # If true, ignores existing edge cache and re-computes.

precomputation:
  # Parameters for expensive, cacheable operations
  flow_engine: "NeuFlow" # RAFT, NeuFlow
  # Model name. For RAFT: 'sintel', 'kitti'.
  # For NeuFlow: 'neuflow_sintel', 'neuflow_mixed', 'neuflow_things'.
  flow_model: "neuflow_sintel"
  edge_method: "Classic" # Classic, PAGE, PST

pipeline:
  pyramid_levels: 6 # Number of levels for the synthesis pyramid
  alpha: 0.75 # Content preservation. 1.0 = full style, 0.0 = full content.
  max_iter: 200
  flip_aug: false # Use flip augmentation for more style variety
  content_loss: false # Use content self-similarity loss
  colorize: false # Use advanced color matching
  final_pass:
    enabled: true
    # How strongly to apply the final pass. 1.0 is a good default.
    # Higher values will add more sharpness/style but risk re-introducing jitter.
    # Lower values will be more stable but softer.
    strength: 1.0

ebsynth_params:
  # Low-level ebsynth.dll parameters
  uniformity: 3500.0
  patch_size: 7
  search_vote_iters: 12
  patch_match_iters: 6
  stop_threshold: 5 # Stop improving pixels when change is less than this value (0-255)
  # Skips random search for pixels with SSD error below this. 0.0 disables. Can provide a large speedup.
  search_pruning_threshold: 50.0
  # Cost function for patch matching. 
  # Sum of Squared Difference: "ssd"
  # Normalized Cross-Correlation: "ncc" is more robust to lighting changes.
  cost_function: "ncc" # Options: "ssd", "ncc"

  # Guide weights
  edge_weight: 1.0
  image_weight: 6.0
  pos_weight: 2.0
  warp_weight: 0.5
  sparse_anchor_weight: 50.0

debug:
  save_flow_viz: false
  flow_viz_dir: "debug/flow_viz"

PyTorchのテンソル結合エラーの修正方法

RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 67 but got size 66 for tensor number 1 in the list.

MIT License

Copyright (c) 2025 Fuou Marinas
Copyright (c) 2025 eightban

MIT License

バージョンを合わせる

git clone https://github.com/FuouM/ReEzSynth.git
cd ReEzSynth
git checkout aaa8d06170e6cc59054410aa9c422edd789f7ab2

backbone_v7

diff  backbone_v7.py backbone_v7.py
110,118d109
< 
<         # ★ ここで pos_s16 の空間サイズを x_16 に合わせる
<         if self.pos_s16.shape[-2:] != x_16.shape[-2:]:
<             self.pos_s16 = F.interpolate(
<                 self.pos_s16,
<                 size=x_16.shape[-2:],
<                 mode="nearest",
<             )
< 

corr

diff  corr.py corr.py
1a2
> 
3a5
> 
12a15
> 
27d29
<         # delta は固定で OK(後で bc に合わせて repeat)
35,36d36
<         delta = torch.stack(torch.meshgrid(xy_range, xy_range, indexing="ij"), dim=-1)
<         self.delta_base = delta.view(1, 2 * self.radius + 1, 2 * self.radius + 1, 2)
38,55c38,42
<     def __call__(self, corr_pyramid, flow):
<         # corr の形状を基準にする
<         corr0 = corr_pyramid[0]                     # (b*h*w, 1, Hc, Wc)
<         bc, _, hc, wc = corr0.shape                # bc = b*h*w
< 
<         b = flow.shape[0]
< 
<         # ★ flow を corr の解像度に合わせる
<         flow = F.interpolate(flow, size=(hc, wc), mode="bilinear", align_corners=True)
< 
<         # ★ coords を作る
<         y, x = torch.meshgrid(
<             torch.arange(hc, device=flow.device),
<             torch.arange(wc, device=flow.device),
<             indexing="ij"
<         )
<         grid = torch.stack((x, y), dim=0).to(flow.dtype)   # (2, hc, wc)
<         grid = grid.unsqueeze(0).repeat(b, 1, 1, 1)         # (b, 2, hc, wc)
---
>         delta = torch.stack(torch.meshgrid(xy_range, xy_range, indexing="ij"), axis=-1)
>         delta = delta.view(1, 2 * self.radius + 1, 2 * self.radius + 1, 2)
> 
>         self.grid = utils.coords_grid(batch_size, height, width, device, amp)
>         self.delta = delta.repeat(batch_size * height * width, 1, 1, 1)
57,58c44,45
<         coords = (grid + flow).permute(0, 2, 3, 1)          # (b, hc, wc, 2)
<         coords = coords.reshape(bc, 1, 1, 2)
---
>     def __call__(self, corr_pyramid, flow):
>         b, _, h, w = flow.shape
60,62c47,48
<         # ★ delta を bc に合わせる
<         delta = self.delta_base.to(flow.dtype).to(flow.device)
<         delta = delta.repeat(bc, 1, 1, 1)
---
>         coords = (self.grid + flow).permute(0, 2, 3, 1)
>         coords = coords.reshape(b * h * w, 1, 1, 2)
64d49
<         # ★ corr_pyramid をサンプリング
65a51
> 
67c53
<             curr_coords = coords / (2 ** level) + delta
---
>             curr_coords = coords / 2**level + self.delta
69c55
<             corr = corr.view(b, hc, wc, -1)
---
>             corr = corr.view(b, h, w, -1)
73d58
<         out = out.permute(0, 3, 1, 2).contiguous()
75,82c60
<         # ★★★ 最重要:CorrBlock の出力を flow と同じサイズに揃える
<         if out.shape[-2:] != flow.shape[-2:]:
<             out = F.interpolate(
<                 out,
<                 size=flow.shape[-2:],   # ← flow0 のサイズに強制一致
<                 mode="bilinear",
<                 align_corners=True,
<             )
---
>         return out.permute(0, 3, 1, 2).contiguous()
84d61
<         return out

matching

diff  matching.py matching.py
1d0
< import torch
9,10c8,13
<         # ここは残しておいてもいいけど、実質使わなくなる
<         self.amp = amp
---
>         self.grid = utils.coords_grid(
>             batch_size, height, width, device, amp
>         )  # [B, 2, H, W]
>         self.flatten_grid = self.grid.view(batch_size, 2, -1).permute(
>             0, 2, 1
>         )  # [B, H*W, 2]
15,26c18,19
<         # ★ 毎回 feature のサイズに合わせて grid を生成
<         self.grid = utils.coords_grid(
<             b,
<             h,
<             w,
<             feature0.device,
<             self.amp if hasattr(self, "amp") else (feature0.dtype == torch.half),
<         )  # [B, 2, H, W]
<         self.flatten_grid = self.grid.view(b, 2, -1).permute(0, 2, 1)  # [B, H*W, 2]
< 
<         feature0 = feature0.flatten(-2).permute(0, 2, 1)  # [B, H*W, C]
<         feature1 = feature1.flatten(-2).permute(0, 2, 1)  # [B, H*W, C]
---
>         feature0 = feature0.flatten(-2).permute(0, 2, 1)
>         feature1 = feature1.flatten(-2).permute(0, 2, 1)
30c23
<         )  # [B, H*W, 2]
---
>         )
36c29
<         flow = correspondence - self.grid  # ★ ここで必ずサイズ一致
---
>         flow = correspondence - self.grid

neuflow

diff  neuflow.py neuflow.py
143,153d142
<         # ---------------------------------------------------------
<         # ★ 1) 入力画像を 16 の倍数にパディング(最重要)
<         # ---------------------------------------------------------
<         B, C, H0, W0 = img0.shape
<         pad_h = (16 - H0 % 16) % 16
<         pad_w = (16 - W0 % 16) % 16
< 
<         if pad_h != 0 or pad_w != 0:
<             img0 = F.pad(img0, (0, pad_w, 0, pad_h))
<             img1 = F.pad(img1, (0, pad_w, 0, pad_h))
< 
157,159d145
<         # ---------------------------------------------------------
<         # ★ 2) backbone → s16 / s8 features
<         # ---------------------------------------------------------
162d147
< 
168a154
>         feature0_s16, feature1_s16 = features_s16.chunk(chunks=2, dim=0)
170,174d155
<         feature0_s16, feature1_s16 = features_s16.chunk(2, dim=0)
< 
<         # ---------------------------------------------------------
<         # ★ 3) s16 matching
<         # ---------------------------------------------------------
182d162
< 
188d167
< 
193,201c172,173
<         # ---------------------------------------------------------
<         # ★ 4) s16 → s8 にアップサンプル(size を明示)
<         # ---------------------------------------------------------
<         flow0 = F.interpolate(flow0, size=features_s8.shape[2:], mode="nearest") * 2
<         features_s16 = F.interpolate(features_s16, size=features_s8.shape[2:], mode="nearest")
< 
<         # ---------------------------------------------------------
<         # ★ 5) s16 features を s8 にマージ
<         # ---------------------------------------------------------
---
>         flow0 = F.interpolate(flow0, scale_factor=2, mode="nearest") * 2
>         features_s16 = F.interpolate(features_s16, scale_factor=2, mode="nearest")
203c175
<         feature0_s8, feature1_s8 = features_s8.chunk(2, dim=0)
---
>         feature0_s8, feature1_s8 = features_s8.chunk(chunks=2, dim=0)
205,209c177
< 
<         # ---------------------------------------------------------
<         # ★ 6) s16 context → s8 context(size を明示)
<         # ---------------------------------------------------------
<         context_s16 = F.interpolate(context_s16, size=context_s8.shape[2:], mode="nearest")
---
>         context_s16 = F.interpolate(context_s16, scale_factor=2, mode="nearest")
213,215d180
<         # ---------------------------------------------------------
<         # ★ 7) s8 refine
<         # ---------------------------------------------------------
219d183
< 
225d188
< 
229,233d191
< 
<                 # ---------------------------------------------------------
<                 # ★ 8) 最後に元の画像サイズにクロップ
<                 # ---------------------------------------------------------
<                 up_flow0 = up_flow0[:, :, :H0, :W0]

refine

diff  refine.py refine.py
78,106d77
< 
<         _, _, H, W = context.shape
< 
<         # ★ flow0 を context に揃える
<         if flow0.shape[-2:] != (H, W):
<             flow0 = torch.nn.functional.interpolate(
<                 flow0,
<                 size=(H, W),
<                 mode="bilinear",
<                 align_corners=True,
<             )
< 
<         # ★ iter_context も context に揃える ← ここが抜けていた
<         if iter_context.shape[-2:] != (H, W):
<             iter_context = torch.nn.functional.interpolate(
<                 iter_context,
<                 size=(H, W),
<                 mode="bilinear",
<                 align_corners=True,
<             )
< 
<         # ★ radius_emb も context に揃える
<         if self.radius_emb.shape[-2:] != (H, W):
<             self.radius_emb = torch.nn.functional.interpolate(
<                 self.radius_emb,
<                 size=(H, W),
<                 mode="nearest",
<             )
< 
109a81
> 

Python

Posted by eightban