YOLO 실습 코드 : run_transt_custom.py
필기자
2025-06-26 17:54
11
0
본문
import sys, os, cv2, numpy as np # 로그 및 환경설정 패키지
sys.path.append(os.path.abspath("TransT")) # 환경변수 경로 지정
#TransT 모델 모듈 가져오기
from pytracking.parameter.transt.transt50 import parameters as transt50
from pytracking.tracker.transt import TransT
#영상 처리 + 배열 연산을 위한 패키지
video_path = 'test_obsticle_video.mp4' # 영상 데이터 경로와 파일이름
start_frame = 29 # 트래킹을 시작할 프레임 지정
# OpenCV로 영상 로드
cap = cv2.VideoCapture(video_path)
cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
ret, frame = cap.read()
if not ret:
raise RuntimeError(f"{start_frame}번째 프레임 로딩 실패")
# 마우스 커서로 객체 선택 (바운딩박스 지정)
print("추적할 객체를 마우스로 모두 드래그 후 Enter 누르세요.")
# 마우스로 최초 좌표 지정, 지정 완료 후 스페이스바 or 엔터, 좌표지정 종료 : ESC
bboxes = cv2.selectROIs("Select Objects", frame, fromCenter=False, showCrosshair=True)
cv2.destroyAllWindows()
if len(bboxes) == 0:
raise RuntimeError("선택된 객체가 없습니다.")
# 라벨링한 객체의 클래스 ID 입력
class_ids = [int(input(f"클래스 ID 입력 (박스 {i}): ")) for i in range(len(bboxes))]
# 트래킹 세팅
trackers = []
for box, cid in zip(bboxes, class_ids):
tracker = TransT(transt50())
tracker.initialize(frame, {'init_bbox': np.array(box)})
tracker.class_id, tracker.bbox_last = cid, box
trackers.append(tracker)
# 라벨링 된 파일 저장 폴더 생성
os.makedirs("results/images", exist_ok=True)
os.makedirs("results/labels", exist_ok=True)
frame_id = start_frame
# 연산 시작 루프
while cap.isOpened():
ret, frame = cap.read() # 영상 프레임 읽어오기
if not ret: break
frame_copy = frame.copy()
H, W = frame.shape[:2] # 프레임 높이,넓이 읽기
labels = []
for i, trk in enumerate(trackers): # 루프 시작 : 루프횟수 = 지정한 객체 수
try:
bbox = trk.track(frame)['target_bbox'] # 모델 추론, 바운딩박스 정보 획득
trk.bbox_last = bbox
except:
bbox = trk.bbox_last
x, y, w, h = map(int, bbox) # 바운딩 박스 정보를 x, y, w, h로 나누어 각각 저장
xc, yc, wn, hn = (x + w/2)/W, (y + h/2)/H, w/W, h/H # 좌표값 YOLO 라벨링 방식으로 변환
labels.append(f"{trk.class_id} {xc:.6f} {yc:.6f} {wn:.6f} {hn:.6f}")
# 출력 영상에 좌표에 따른 바운딩박스 + 라벨번호 그리기
cv2.rectangle(frame_copy, (x, y), (x + w, y + h), (0, 255, 0), 2)
cv2.putText(frame_copy, f"ID {i}", (x, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255,255,255), 1)
# 프레임 단위 이미지 + 주석 파일 저장
cv2.imwrite(f"results/images/frame_{frame_id:04d}.jpg", frame)
with open(f"results/labels/frame_{frame_id:04d}.txt", "w") as f:
f.write("\n".join(labels))
print(f"[{frame_id:04d}] 저장 - 객체 {len(labels)}개")
# 트래킹 된 프레임 출력
cv2.imshow("Tracking", frame_copy)
if cv2.waitKey(1) == 27: break # 프로그램 종료 : ESC
frame_id += 1
cap.release()
cv2.destroyAllWindows()
print("전체 추적 완료")
댓글목록0